aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorLiam Mitchell <[email protected]>2026-03-09 19:06:36 -0700
committerLiam Mitchell <[email protected]>2026-03-09 19:06:36 -0700
commitd1abc50ee9d4fb72efc646e17decafea741caa34 (patch)
treee4288e00f2f7ca0391b83d986efcb69d3ba66a83 /src
parentAllow requests with invalid content-types unless specified in command line or... (diff)
parentupdated chunk–block analyser (#818) (diff)
downloadzen-d1abc50ee9d4fb72efc646e17decafea741caa34.tar.xz
zen-d1abc50ee9d4fb72efc646e17decafea741caa34.zip
Merge branch 'main' into lm/restrict-content-type
Diffstat (limited to 'src')
-rw-r--r--src/zen.icobin0 -> 12957 bytes
-rw-r--r--src/zen/authutils.cpp80
-rw-r--r--src/zen/cmds/admin_cmd.h40
-rw-r--r--src/zen/cmds/bench_cmd.h5
-rw-r--r--src/zen/cmds/builds_cmd.cpp217
-rw-r--r--src/zen/cmds/builds_cmd.h2
-rw-r--r--src/zen/cmds/cache_cmd.h20
-rw-r--r--src/zen/cmds/copy_cmd.h5
-rw-r--r--src/zen/cmds/dedup_cmd.h5
-rw-r--r--src/zen/cmds/exec_cmd.cpp1374
-rw-r--r--src/zen/cmds/exec_cmd.h101
-rw-r--r--src/zen/cmds/info_cmd.h5
-rw-r--r--src/zen/cmds/print_cmd.cpp4
-rw-r--r--src/zen/cmds/print_cmd.h10
-rw-r--r--src/zen/cmds/projectstore_cmd.cpp68
-rw-r--r--src/zen/cmds/projectstore_cmd.h60
-rw-r--r--src/zen/cmds/rpcreplay_cmd.h15
-rw-r--r--src/zen/cmds/run_cmd.h5
-rw-r--r--src/zen/cmds/serve_cmd.h5
-rw-r--r--src/zen/cmds/status_cmd.h5
-rw-r--r--src/zen/cmds/top_cmd.h10
-rw-r--r--src/zen/cmds/trace_cmd.h7
-rw-r--r--src/zen/cmds/ui_cmd.cpp236
-rw-r--r--src/zen/cmds/ui_cmd.h32
-rw-r--r--src/zen/cmds/up_cmd.h15
-rw-r--r--src/zen/cmds/vfs_cmd.h5
-rw-r--r--src/zen/cmds/wipe_cmd.cpp16
-rw-r--r--src/zen/cmds/workspaces_cmd.cpp4
-rw-r--r--src/zen/progressbar.cpp73
-rw-r--r--src/zen/progressbar.h1
-rw-r--r--src/zen/xmake.lua5
-rw-r--r--src/zen/zen.cpp232
-rw-r--r--src/zen/zen.h42
-rw-r--r--src/zen/zen.rc2
-rw-r--r--src/zenbase/include/zenbase/refcount.h3
-rw-r--r--src/zencompute-test/xmake.lua8
-rw-r--r--src/zencompute-test/zencompute-test.cpp16
-rw-r--r--src/zencompute/CLAUDE.md232
-rw-r--r--src/zencompute/cloudmetadata.cpp1014
-rw-r--r--src/zencompute/computeservice.cpp2236
-rw-r--r--src/zencompute/httpcomputeservice.cpp1643
-rw-r--r--src/zencompute/httporchestrator.cpp650
-rw-r--r--src/zencompute/include/zencompute/cloudmetadata.h151
-rw-r--r--src/zencompute/include/zencompute/computeservice.h262
-rw-r--r--src/zencompute/include/zencompute/httpcomputeservice.h54
-rw-r--r--src/zencompute/include/zencompute/httporchestrator.h101
-rw-r--r--src/zencompute/include/zencompute/mockimds.h102
-rw-r--r--src/zencompute/include/zencompute/orchestratorservice.h177
-rw-r--r--src/zencompute/include/zencompute/recordingreader.h129
-rw-r--r--src/zencompute/include/zencompute/zencompute.h15
-rw-r--r--src/zencompute/orchestratorservice.cpp710
-rw-r--r--src/zencompute/recording/actionrecorder.cpp258
-rw-r--r--src/zencompute/recording/actionrecorder.h91
-rw-r--r--src/zencompute/recording/recordingreader.cpp335
-rw-r--r--src/zencompute/runners/deferreddeleter.cpp340
-rw-r--r--src/zencompute/runners/deferreddeleter.h68
-rw-r--r--src/zencompute/runners/functionrunner.cpp365
-rw-r--r--src/zencompute/runners/functionrunner.h214
-rw-r--r--src/zencompute/runners/linuxrunner.cpp734
-rw-r--r--src/zencompute/runners/linuxrunner.h44
-rw-r--r--src/zencompute/runners/localrunner.cpp674
-rw-r--r--src/zencompute/runners/localrunner.h138
-rw-r--r--src/zencompute/runners/macrunner.cpp491
-rw-r--r--src/zencompute/runners/macrunner.h43
-rw-r--r--src/zencompute/runners/remotehttprunner.cpp618
-rw-r--r--src/zencompute/runners/remotehttprunner.h100
-rw-r--r--src/zencompute/runners/windowsrunner.cpp460
-rw-r--r--src/zencompute/runners/windowsrunner.h53
-rw-r--r--src/zencompute/runners/winerunner.cpp237
-rw-r--r--src/zencompute/runners/winerunner.h37
-rw-r--r--src/zencompute/testing/mockimds.cpp205
-rw-r--r--src/zencompute/timeline/workertimeline.cpp430
-rw-r--r--src/zencompute/timeline/workertimeline.h169
-rw-r--r--src/zencompute/xmake.lua19
-rw-r--r--src/zencompute/zencompute.cpp21
-rw-r--r--src/zencore-test/zencore-test.cpp36
-rw-r--r--src/zencore/base64.cpp196
-rw-r--r--src/zencore/basicfile.cpp4
-rw-r--r--src/zencore/blake3.cpp6
-rw-r--r--src/zencore/callstack.cpp4
-rw-r--r--src/zencore/commandline.cpp1
-rw-r--r--src/zencore/compactbinary.cpp8
-rw-r--r--src/zencore/compactbinarybuilder.cpp4
-rw-r--r--src/zencore/compactbinaryjson.cpp4
-rw-r--r--src/zencore/compactbinarypackage.cpp4
-rw-r--r--src/zencore/compactbinaryvalidation.cpp4
-rw-r--r--src/zencore/compactbinaryyaml.cpp431
-rw-r--r--src/zencore/compositebuffer.cpp5
-rw-r--r--src/zencore/compress.cpp4
-rw-r--r--src/zencore/crypto.cpp4
-rw-r--r--src/zencore/filesystem.cpp36
-rw-r--r--src/zencore/include/zencore/base64.h4
-rw-r--r--src/zencore/include/zencore/blockingqueue.h2
-rw-r--r--src/zencore/include/zencore/compactbinaryfile.h1
-rw-r--r--src/zencore/include/zencore/compactbinaryvalue.h24
-rw-r--r--src/zencore/include/zencore/filesystem.h38
-rw-r--r--src/zencore/include/zencore/hashutils.h4
-rw-r--r--src/zencore/include/zencore/iobuffer.h37
-rw-r--r--src/zencore/include/zencore/logbase.h113
-rw-r--r--src/zencore/include/zencore/logging.h215
-rw-r--r--src/zencore/include/zencore/logging/ansicolorsink.h33
-rw-r--r--src/zencore/include/zencore/logging/asyncsink.h30
-rw-r--r--src/zencore/include/zencore/logging/formatter.h20
-rw-r--r--src/zencore/include/zencore/logging/helpers.h122
-rw-r--r--src/zencore/include/zencore/logging/logger.h63
-rw-r--r--src/zencore/include/zencore/logging/logmsg.h66
-rw-r--r--src/zencore/include/zencore/logging/memorybuffer.h11
-rw-r--r--src/zencore/include/zencore/logging/messageonlyformatter.h22
-rw-r--r--src/zencore/include/zencore/logging/msvcsink.h30
-rw-r--r--src/zencore/include/zencore/logging/nullsink.h17
-rw-r--r--src/zencore/include/zencore/logging/registry.h70
-rw-r--r--src/zencore/include/zencore/logging/sink.h34
-rw-r--r--src/zencore/include/zencore/logging/tracesink.h27
-rw-r--r--src/zencore/include/zencore/md5.h2
-rw-r--r--src/zencore/include/zencore/meta.h1
-rw-r--r--src/zencore/include/zencore/mpscqueue.h20
-rw-r--r--src/zencore/include/zencore/process.h34
-rw-r--r--src/zencore/include/zencore/sentryintegration.h9
-rw-r--r--src/zencore/include/zencore/sharedbuffer.h13
-rw-r--r--src/zencore/include/zencore/string.h40
-rw-r--r--src/zencore/include/zencore/system.h37
-rw-r--r--src/zencore/include/zencore/testing.h7
-rw-r--r--src/zencore/include/zencore/testutils.h27
-rw-r--r--src/zencore/include/zencore/thread.h24
-rw-r--r--src/zencore/include/zencore/trace.h1
-rw-r--r--src/zencore/include/zencore/varint.h1
-rw-r--r--src/zencore/include/zencore/xxhash.h2
-rw-r--r--src/zencore/include/zencore/zencore.h34
-rw-r--r--src/zencore/intmath.cpp10
-rw-r--r--src/zencore/iobuffer.cpp24
-rw-r--r--src/zencore/jobqueue.cpp24
-rw-r--r--src/zencore/logging.cpp342
-rw-r--r--src/zencore/logging/ansicolorsink.cpp273
-rw-r--r--src/zencore/logging/asyncsink.cpp212
-rw-r--r--src/zencore/logging/logger.cpp142
-rw-r--r--src/zencore/logging/msvcsink.cpp80
-rw-r--r--src/zencore/logging/registry.cpp330
-rw-r--r--src/zencore/logging/tracesink.cpp92
-rw-r--r--src/zencore/md5.cpp47
-rw-r--r--src/zencore/memoryview.cpp4
-rw-r--r--src/zencore/memtrack/callstacktrace.cpp8
-rw-r--r--src/zencore/memtrack/tagtrace.cpp2
-rw-r--r--src/zencore/mpscqueue.cpp6
-rw-r--r--src/zencore/parallelwork.cpp4
-rw-r--r--src/zencore/process.cpp329
-rw-r--r--src/zencore/refcount.cpp4
-rw-r--r--src/zencore/sentryintegration.cpp196
-rw-r--r--src/zencore/sha1.cpp4
-rw-r--r--src/zencore/sharedbuffer.cpp4
-rw-r--r--src/zencore/stream.cpp4
-rw-r--r--src/zencore/string.cpp210
-rw-r--r--src/zencore/system.cpp407
-rw-r--r--src/zencore/testing.cpp134
-rw-r--r--src/zencore/testutils.cpp2
-rw-r--r--src/zencore/thread.cpp42
-rw-r--r--src/zencore/trace.cpp22
-rw-r--r--src/zencore/uid.cpp4
-rw-r--r--src/zencore/windows.cpp12
-rw-r--r--src/zencore/workthreadpool.cpp4
-rw-r--r--src/zencore/xmake.lua4
-rw-r--r--src/zencore/xxhash.cpp4
-rw-r--r--src/zencore/zencore.cpp4
-rw-r--r--src/zenhorde/hordeagent.cpp297
-rw-r--r--src/zenhorde/hordeagent.h77
-rw-r--r--src/zenhorde/hordeagentmessage.cpp340
-rw-r--r--src/zenhorde/hordeagentmessage.h161
-rw-r--r--src/zenhorde/hordebundle.cpp619
-rw-r--r--src/zenhorde/hordebundle.h49
-rw-r--r--src/zenhorde/hordeclient.cpp382
-rw-r--r--src/zenhorde/hordecomputebuffer.cpp454
-rw-r--r--src/zenhorde/hordecomputebuffer.h136
-rw-r--r--src/zenhorde/hordecomputechannel.cpp37
-rw-r--r--src/zenhorde/hordecomputechannel.h32
-rw-r--r--src/zenhorde/hordecomputesocket.cpp204
-rw-r--r--src/zenhorde/hordecomputesocket.h79
-rw-r--r--src/zenhorde/hordeconfig.cpp89
-rw-r--r--src/zenhorde/hordeprovisioner.cpp367
-rw-r--r--src/zenhorde/hordetransport.cpp169
-rw-r--r--src/zenhorde/hordetransport.h71
-rw-r--r--src/zenhorde/hordetransportaes.cpp425
-rw-r--r--src/zenhorde/hordetransportaes.h52
-rw-r--r--src/zenhorde/include/zenhorde/hordeclient.h116
-rw-r--r--src/zenhorde/include/zenhorde/hordeconfig.h62
-rw-r--r--src/zenhorde/include/zenhorde/hordeprovisioner.h110
-rw-r--r--src/zenhorde/include/zenhorde/zenhorde.h9
-rw-r--r--src/zenhorde/xmake.lua22
-rw-r--r--src/zenhttp-test/zenhttp-test.cpp35
-rw-r--r--src/zenhttp/auth/oidc.cpp24
-rw-r--r--src/zenhttp/clients/httpclientcommon.cpp323
-rw-r--r--src/zenhttp/clients/httpclientcommon.h115
-rw-r--r--src/zenhttp/clients/httpclientcpr.cpp579
-rw-r--r--src/zenhttp/clients/httpclientcpr.h15
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp566
-rw-r--r--src/zenhttp/httpclient.cpp451
-rw-r--r--src/zenhttp/httpclient_test.cpp1366
-rw-r--r--src/zenhttp/httpclientauth.cpp2
-rw-r--r--src/zenhttp/httpserver.cpp167
-rw-r--r--src/zenhttp/include/zenhttp/cprutils.h4
-rw-r--r--src/zenhttp/include/zenhttp/formatters.h2
-rw-r--r--src/zenhttp/include/zenhttp/httpapiservice.h1
-rw-r--r--src/zenhttp/include/zenhttp/httpclient.h53
-rw-r--r--src/zenhttp/include/zenhttp/httpcommon.h7
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h147
-rw-r--r--src/zenhttp/include/zenhttp/httpstats.h47
-rw-r--r--src/zenhttp/include/zenhttp/httpwsclient.h79
-rw-r--r--src/zenhttp/include/zenhttp/packageformat.h2
-rw-r--r--src/zenhttp/include/zenhttp/security/passwordsecurity.h38
-rw-r--r--src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h51
-rw-r--r--src/zenhttp/include/zenhttp/websocket.h65
-rw-r--r--src/zenhttp/monitoring/httpstats.cpp195
-rw-r--r--src/zenhttp/packageformat.cpp6
-rw-r--r--src/zenhttp/security/passwordsecurity.cpp176
-rw-r--r--src/zenhttp/security/passwordsecurityfilter.cpp56
-rw-r--r--src/zenhttp/servers/httpasio.cpp427
-rw-r--r--src/zenhttp/servers/httpasio.h2
-rw-r--r--src/zenhttp/servers/httpmulti.cpp31
-rw-r--r--src/zenhttp/servers/httpmulti.h12
-rw-r--r--src/zenhttp/servers/httpnull.cpp18
-rw-r--r--src/zenhttp/servers/httpnull.h1
-rw-r--r--src/zenhttp/servers/httpparser.cpp155
-rw-r--r--src/zenhttp/servers/httpparser.h12
-rw-r--r--src/zenhttp/servers/httpplugin.cpp140
-rw-r--r--src/zenhttp/servers/httpsys.cpp556
-rw-r--r--src/zenhttp/servers/httpsys_iocontext.h40
-rw-r--r--src/zenhttp/servers/httptracer.h4
-rw-r--r--src/zenhttp/servers/wsasio.cpp311
-rw-r--r--src/zenhttp/servers/wsasio.h77
-rw-r--r--src/zenhttp/servers/wsframecodec.cpp236
-rw-r--r--src/zenhttp/servers/wsframecodec.h74
-rw-r--r--src/zenhttp/servers/wshttpsys.cpp485
-rw-r--r--src/zenhttp/servers/wshttpsys.h107
-rw-r--r--src/zenhttp/servers/wstest.cpp925
-rw-r--r--src/zenhttp/transports/dlltransport.cpp38
-rw-r--r--src/zenhttp/transports/winsocktransport.cpp2
-rw-r--r--src/zenhttp/xmake.lua1
-rw-r--r--src/zenhttp/zenhttp.cpp4
-rw-r--r--src/zennet-test/zennet-test.cpp34
-rw-r--r--src/zennet/beacon.cpp170
-rw-r--r--src/zennet/include/zennet/beacon.h38
-rw-r--r--src/zennet/include/zennet/statsdclient.h2
-rw-r--r--src/zennet/statsdclient.cpp5
-rw-r--r--src/zennomad/include/zennomad/nomadclient.h77
-rw-r--r--src/zennomad/include/zennomad/nomadconfig.h65
-rw-r--r--src/zennomad/include/zennomad/nomadprocess.h78
-rw-r--r--src/zennomad/include/zennomad/nomadprovisioner.h107
-rw-r--r--src/zennomad/include/zennomad/zennomad.h9
-rw-r--r--src/zennomad/nomadclient.cpp366
-rw-r--r--src/zennomad/nomadconfig.cpp91
-rw-r--r--src/zennomad/nomadprocess.cpp354
-rw-r--r--src/zennomad/nomadprovisioner.cpp264
-rw-r--r--src/zennomad/xmake.lua10
-rw-r--r--src/zenremotestore-test/zenremotestore-test.cpp35
-rw-r--r--src/zenremotestore/builds/buildmanifest.cpp4
-rw-r--r--src/zenremotestore/builds/buildsavedstate.cpp4
-rw-r--r--src/zenremotestore/builds/buildstoragecache.cpp281
-rw-r--r--src/zenremotestore/builds/buildstorageoperations.cpp1636
-rw-r--r--src/zenremotestore/builds/buildstorageutil.cpp107
-rw-r--r--src/zenremotestore/builds/filebuildstorage.cpp39
-rw-r--r--src/zenremotestore/builds/jupiterbuildstorage.cpp22
-rw-r--r--src/zenremotestore/chunking/chunkblock.cpp1378
-rw-r--r--src/zenremotestore/chunking/chunkedcontent.cpp8
-rw-r--r--src/zenremotestore/chunking/chunkedfile.cpp4
-rw-r--r--src/zenremotestore/chunking/chunkingcache.cpp12
-rw-r--r--src/zenremotestore/filesystemutils.cpp4
-rw-r--r--src/zenremotestore/include/zenremotestore/builds/buildstorage.h19
-rw-r--r--src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h17
-rw-r--r--src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h76
-rw-r--r--src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h33
-rw-r--r--src/zenremotestore/include/zenremotestore/chunking/chunkblock.h72
-rw-r--r--src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h2
-rw-r--r--src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h3
-rw-r--r--src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h12
-rw-r--r--src/zenremotestore/include/zenremotestore/operationlogoutput.h29
-rw-r--r--src/zenremotestore/include/zenremotestore/partialblockrequestmode.h20
-rw-r--r--src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h20
-rw-r--r--src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h97
-rw-r--r--src/zenremotestore/jupiter/jupiterhost.cpp17
-rw-r--r--src/zenremotestore/jupiter/jupitersession.cpp65
-rw-r--r--src/zenremotestore/operationlogoutput.cpp16
-rw-r--r--src/zenremotestore/partialblockrequestmode.cpp27
-rw-r--r--src/zenremotestore/projectstore/buildsremoteprojectstore.cpp316
-rw-r--r--src/zenremotestore/projectstore/fileremoteprojectstore.cpp255
-rw-r--r--src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp42
-rw-r--r--src/zenremotestore/projectstore/projectstoreoperations.cpp14
-rw-r--r--src/zenremotestore/projectstore/remoteprojectstore.cpp2522
-rw-r--r--src/zenremotestore/projectstore/zenremoteprojectstore.cpp50
-rw-r--r--src/zenserver-test/buildstore-tests.cpp220
-rw-r--r--src/zenserver-test/cache-tests.cpp14
-rw-r--r--src/zenserver-test/cacherequests.cpp4
-rw-r--r--src/zenserver-test/compute-tests.cpp1700
-rw-r--r--src/zenserver-test/hub-tests.cpp8
-rw-r--r--src/zenserver-test/logging-tests.cpp261
-rw-r--r--src/zenserver-test/nomad-tests.cpp130
-rw-r--r--src/zenserver-test/objectstore-tests.cpp74
-rw-r--r--src/zenserver-test/projectstore-tests.cpp38
-rw-r--r--src/zenserver-test/workspace-tests.cpp8
-rw-r--r--src/zenserver-test/xmake.lua7
-rw-r--r--src/zenserver-test/zenserver-test.cpp26
-rw-r--r--src/zenserver/compute/computeserver.cpp1021
-rw-r--r--src/zenserver/compute/computeserver.h188
-rw-r--r--src/zenserver/config/config.cpp93
-rw-r--r--src/zenserver/config/config.h44
-rw-r--r--src/zenserver/config/luaconfig.h2
-rw-r--r--src/zenserver/diag/diagsvcs.cpp37
-rw-r--r--src/zenserver/diag/diagsvcs.h15
-rw-r--r--src/zenserver/diag/logging.cpp61
-rw-r--r--src/zenserver/diag/otlphttp.cpp63
-rw-r--r--src/zenserver/diag/otlphttp.h28
-rw-r--r--src/zenserver/frontend/frontend.cpp56
-rw-r--r--src/zenserver/frontend/frontend.h7
-rw-r--r--src/zenserver/frontend/html.zipbin163229 -> 406051 bytes
-rw-r--r--src/zenserver/frontend/html/404.html486
-rw-r--r--src/zenserver/frontend/html/banner.js338
-rw-r--r--src/zenserver/frontend/html/compute/compute.html929
-rw-r--r--src/zenserver/frontend/html/compute/hub.html170
-rw-r--r--src/zenserver/frontend/html/compute/index.html1
-rw-r--r--src/zenserver/frontend/html/compute/orchestrator.html674
-rw-r--r--src/zenserver/frontend/html/epicgames.ico (renamed from src/UnrealEngine.ico)bin65288 -> 65288 bytes
-rw-r--r--src/zenserver/frontend/html/favicon.icobin65288 -> 12957 bytes
-rw-r--r--src/zenserver/frontend/html/index.html3
-rw-r--r--src/zenserver/frontend/html/nav.js79
-rw-r--r--src/zenserver/frontend/html/pages/cache.js690
-rw-r--r--src/zenserver/frontend/html/pages/compute.js693
-rw-r--r--src/zenserver/frontend/html/pages/cookartifacts.js397
-rw-r--r--src/zenserver/frontend/html/pages/entry.js341
-rw-r--r--src/zenserver/frontend/html/pages/hub.js122
-rw-r--r--src/zenserver/frontend/html/pages/info.js261
-rw-r--r--src/zenserver/frontend/html/pages/map.js4
-rw-r--r--src/zenserver/frontend/html/pages/metrics.js232
-rw-r--r--src/zenserver/frontend/html/pages/oplog.js4
-rw-r--r--src/zenserver/frontend/html/pages/orchestrator.js405
-rw-r--r--src/zenserver/frontend/html/pages/page.js120
-rw-r--r--src/zenserver/frontend/html/pages/project.js2
-rw-r--r--src/zenserver/frontend/html/pages/projects.js447
-rw-r--r--src/zenserver/frontend/html/pages/sessions.js61
-rw-r--r--src/zenserver/frontend/html/pages/start.js327
-rw-r--r--src/zenserver/frontend/html/pages/stat.js2
-rw-r--r--src/zenserver/frontend/html/pages/tree.js2
-rw-r--r--src/zenserver/frontend/html/pages/zcache.js8
-rw-r--r--src/zenserver/frontend/html/theme.js116
-rw-r--r--src/zenserver/frontend/html/util/compactbinary.js4
-rw-r--r--src/zenserver/frontend/html/util/friendly.js21
-rw-r--r--src/zenserver/frontend/html/util/widgets.js138
-rw-r--r--src/zenserver/frontend/html/zen.css824
-rw-r--r--src/zenserver/frontend/zipfs.cpp20
-rw-r--r--src/zenserver/frontend/zipfs.h8
-rw-r--r--src/zenserver/hub/hubservice.cpp68
-rw-r--r--src/zenserver/hub/hubservice.h7
-rw-r--r--src/zenserver/hub/zenhubserver.cpp11
-rw-r--r--src/zenserver/hub/zenhubserver.h6
-rw-r--r--src/zenserver/main.cpp69
-rw-r--r--src/zenserver/sessions/httpsessions.cpp264
-rw-r--r--src/zenserver/sessions/httpsessions.h55
-rw-r--r--src/zenserver/sessions/sessions.cpp150
-rw-r--r--src/zenserver/sessions/sessions.h83
-rw-r--r--src/zenserver/storage/admin/admin.cpp6
-rw-r--r--src/zenserver/storage/buildstore/httpbuildstore.cpp153
-rw-r--r--src/zenserver/storage/buildstore/httpbuildstore.h7
-rw-r--r--src/zenserver/storage/cache/httpstructuredcache.cpp141
-rw-r--r--src/zenserver/storage/cache/httpstructuredcache.h11
-rw-r--r--src/zenserver/storage/projectstore/httpprojectstore.cpp279
-rw-r--r--src/zenserver/storage/projectstore/httpprojectstore.h5
-rw-r--r--src/zenserver/storage/storageconfig.cpp1
-rw-r--r--src/zenserver/storage/storageconfig.h2
-rw-r--r--src/zenserver/storage/workspaces/httpworkspaces.cpp12
-rw-r--r--src/zenserver/storage/workspaces/httpworkspaces.h5
-rw-r--r--src/zenserver/storage/zenstorageserver.cpp55
-rw-r--r--src/zenserver/storage/zenstorageserver.h30
-rw-r--r--src/zenserver/trace/tracerecorder.cpp565
-rw-r--r--src/zenserver/trace/tracerecorder.h46
-rw-r--r--src/zenserver/xmake.lua40
-rw-r--r--src/zenserver/zenserver.cpp167
-rw-r--r--src/zenserver/zenserver.h38
-rw-r--r--src/zenserver/zenserver.rc2
-rw-r--r--src/zenstore-test/zenstore-test.cpp34
-rw-r--r--src/zenstore/blockstore.cpp4
-rw-r--r--src/zenstore/buildstore/buildstore.cpp23
-rw-r--r--src/zenstore/cache/cachedisklayer.cpp40
-rw-r--r--src/zenstore/cache/cachepolicy.cpp19
-rw-r--r--src/zenstore/cache/cacherpc.cpp6
-rw-r--r--src/zenstore/cache/structuredcachestore.cpp19
-rw-r--r--src/zenstore/cas.cpp27
-rw-r--r--src/zenstore/caslog.cpp6
-rw-r--r--src/zenstore/cidstore.cpp3
-rw-r--r--src/zenstore/compactcas.cpp20
-rw-r--r--src/zenstore/filecas.cpp18
-rw-r--r--src/zenstore/filecas.h2
-rw-r--r--src/zenstore/gc.cpp11
-rw-r--r--src/zenstore/include/zenstore/buildstore/buildstore.h4
-rw-r--r--src/zenstore/include/zenstore/cache/cachedisklayer.h34
-rw-r--r--src/zenstore/include/zenstore/cache/cacheshared.h6
-rw-r--r--src/zenstore/include/zenstore/cache/structuredcachestore.h12
-rw-r--r--src/zenstore/include/zenstore/caslog.h10
-rw-r--r--src/zenstore/include/zenstore/gc.h8
-rw-r--r--src/zenstore/include/zenstore/projectstore.h10
-rw-r--r--src/zenstore/projectstore.cpp39
-rw-r--r--src/zenstore/workspaces.cpp20
-rw-r--r--src/zentelemetry-test/zentelemetry-test.cpp34
-rw-r--r--src/zentelemetry/include/zentelemetry/otlpencoder.h8
-rw-r--r--src/zentelemetry/include/zentelemetry/otlptrace.h9
-rw-r--r--src/zentelemetry/include/zentelemetry/stats.h203
-rw-r--r--src/zentelemetry/otlpencoder.cpp44
-rw-r--r--src/zentelemetry/otlptrace.cpp4
-rw-r--r--src/zentelemetry/stats.cpp6
-rw-r--r--src/zentelemetry/xmake.lua2
-rw-r--r--src/zentest-appstub/xmake.lua1
-rw-r--r--src/zentest-appstub/zentest-appstub.cpp401
-rw-r--r--src/zenutil-test/zenutil-test.cpp34
-rw-r--r--src/zenutil/config/commandlineoptions.cpp (renamed from src/zenutil/commandlineoptions.cpp)7
-rw-r--r--src/zenutil/config/environmentoptions.cpp (renamed from src/zenutil/environmentoptions.cpp)2
-rw-r--r--src/zenutil/config/loggingconfig.cpp77
-rw-r--r--src/zenutil/consoletui.cpp483
-rw-r--r--src/zenutil/include/zenutil/config/commandlineoptions.h (renamed from src/zenutil/include/zenutil/commandlineoptions.h)0
-rw-r--r--src/zenutil/include/zenutil/config/environmentoptions.h (renamed from src/zenutil/include/zenutil/environmentoptions.h)2
-rw-r--r--src/zenutil/include/zenutil/config/loggingconfig.h37
-rw-r--r--src/zenutil/include/zenutil/consoletui.h60
-rw-r--r--src/zenutil/include/zenutil/logging.h11
-rw-r--r--src/zenutil/include/zenutil/logging/fullformatter.h89
-rw-r--r--src/zenutil/include/zenutil/logging/jsonformatter.h168
-rw-r--r--src/zenutil/include/zenutil/logging/rotatingfilesink.h89
-rw-r--r--src/zenutil/include/zenutil/logging/testformatter.h160
-rw-r--r--src/zenutil/include/zenutil/zenserverprocess.h24
-rw-r--r--src/zenutil/logging.cpp149
-rw-r--r--src/zenutil/rpcrecording.cpp2
-rw-r--r--src/zenutil/wildcard.cpp4
-rw-r--r--src/zenutil/xmake.lua2
-rw-r--r--src/zenutil/zenserverprocess.cpp31
-rw-r--r--src/zenutil/zenutil.cpp2
-rw-r--r--src/zenvfs/xmake.lua2
429 files changed, 54397 insertions, 5283 deletions
diff --git a/src/zen.ico b/src/zen.ico
new file mode 100644
index 000000000..f7fb251b5
--- /dev/null
+++ b/src/zen.ico
Binary files differ
diff --git a/src/zen/authutils.cpp b/src/zen/authutils.cpp
index 31db82efd..534f7952b 100644
--- a/src/zen/authutils.cpp
+++ b/src/zen/authutils.cpp
@@ -154,21 +154,34 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops,
ZEN_ASSERT(!SystemRootDir.empty());
if (!Auth)
{
- if (m_EncryptionKey.empty())
+ static const std::string_view DefaultEncryptionKey("abcdefghijklmnopqrstuvxyz0123456");
+ static const std::string_view DefaultEncryptionIV("0123456789abcdef");
+ if (m_EncryptionKey.empty() && m_EncryptionIV.empty())
{
- m_EncryptionKey = "abcdefghijklmnopqrstuvxyz0123456";
+ m_EncryptionKey = DefaultEncryptionKey;
+ m_EncryptionIV = DefaultEncryptionIV;
if (!Quiet)
{
- ZEN_CONSOLE_WARN("Using default encryption key");
+ ZEN_CONSOLE_WARN("Auth: Using default encryption key and initialization vector for auth storage");
}
}
-
- if (m_EncryptionIV.empty())
+ else
{
- m_EncryptionIV = "0123456789abcdef";
- if (!Quiet)
+ if (m_EncryptionKey.empty())
+ {
+ m_EncryptionKey = DefaultEncryptionKey;
+ if (!Quiet)
+ {
+ ZEN_CONSOLE_WARN("Auth: Using default encryption key for auth storage");
+ }
+ }
+ if (m_EncryptionIV.empty())
{
- ZEN_CONSOLE_WARN("Using default encryption initialization vector");
+ m_EncryptionIV = DefaultEncryptionIV;
+ if (!Quiet)
+ {
+ ZEN_CONSOLE_WARN("Auth: Using default encryption initialization vector for auth storage");
+ }
}
}
@@ -187,9 +200,9 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops,
{
ExtendableStringBuilder<128> SB;
SB << "\n RootDirectory: " << AuthMgrConfig.RootDirectory.string();
- SB << "\n EncryptionKey: " << m_EncryptionKey;
- SB << "\n EncryptionIV: " << m_EncryptionIV;
- ZEN_CONSOLE("Creating auth manager with:{}", SB.ToString());
+ SB << "\n EncryptionKey: " << HideSensitiveString(m_EncryptionKey);
+ SB << "\n EncryptionIV: " << HideSensitiveString(m_EncryptionIV);
+ ZEN_CONSOLE("Auth: Creating auth manager with:{}", SB.ToString());
}
Auth = AuthMgr::Create(AuthMgrConfig);
}
@@ -204,13 +217,18 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops,
ExtendableStringBuilder<128> SB;
SB << "\n Name: " << ProviderName;
SB << "\n Url: " << m_OpenIdProviderUrl;
- SB << "\n ClientId: " << m_OpenIdClientId;
- ZEN_CONSOLE("Adding openid auth provider:{}", SB.ToString());
+ SB << "\n ClientId: " << HideSensitiveString(m_OpenIdClientId);
+ ZEN_CONSOLE("Auth: Adding Open ID auth provider:{}", SB.ToString());
}
Auth->AddOpenIdProvider({.Name = ProviderName, .Url = m_OpenIdProviderUrl, .ClientId = m_OpenIdClientId});
if (!m_OpenIdRefreshToken.empty())
{
- ZEN_CONSOLE("Adding open id refresh token {} to provider {}", m_OpenIdRefreshToken, ProviderName);
+ if (!Quiet)
+ {
+ ZEN_CONSOLE("Auth: Adding open id refresh token {} to provider {}",
+ HideSensitiveString(m_OpenIdRefreshToken),
+ ProviderName);
+ }
Auth->AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = m_OpenIdRefreshToken});
}
}
@@ -225,21 +243,21 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops,
if (!m_AccessToken.empty())
{
- if (Verbose)
+ if (!Quiet)
{
- ZEN_CONSOLE("Adding static auth token: {}", m_AccessToken);
+ ZEN_CONSOLE("Auth: Using static auth token: {}", HideSensitiveString(m_AccessToken));
}
ClientSettings.AccessTokenProvider = httpclientauth::CreateFromStaticToken(m_AccessToken);
}
else if (!m_AccessTokenPath.empty())
{
- MakeSafeAbsolutePathÍnPlace(m_AccessTokenPath);
+ MakeSafeAbsolutePathInPlace(m_AccessTokenPath);
std::string ResolvedAccessToken = ReadAccessTokenFromJsonFile(m_AccessTokenPath);
if (!ResolvedAccessToken.empty())
{
- if (Verbose)
+ if (!Quiet)
{
- ZEN_CONSOLE("Adding static auth token from {}: {}", m_AccessTokenPath, ResolvedAccessToken);
+ ZEN_CONSOLE("Auth: Adding static auth token from {}: {}", m_AccessTokenPath, HideSensitiveString(ResolvedAccessToken));
}
ClientSettings.AccessTokenProvider = httpclientauth::CreateFromStaticToken(ResolvedAccessToken);
}
@@ -250,9 +268,9 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops,
{
ExtendableStringBuilder<128> SB;
SB << "\n Url: " << m_OAuthUrl;
- SB << "\n ClientId: " << m_OAuthClientId;
- SB << "\n ClientSecret: " << m_OAuthClientSecret;
- ZEN_CONSOLE("Adding oauth provider:{}", SB.ToString());
+ SB << "\n ClientId: " << HideSensitiveString(m_OAuthClientId);
+ SB << "\n ClientSecret: " << HideSensitiveString(m_OAuthClientSecret);
+ ZEN_CONSOLE("Auth: Adding oauth provider:{}", SB.ToString());
}
ClientSettings.AccessTokenProvider = httpclientauth::CreateFromOAuthClientCredentials(
{.Url = m_OAuthUrl, .ClientId = m_OAuthClientId, .ClientSecret = m_OAuthClientSecret});
@@ -260,25 +278,27 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops,
else if (!m_OpenIdProviderName.empty())
{
CreateAuthMgr();
- if (Verbose)
+ if (!Quiet)
{
- ZEN_CONSOLE("Using openid provider: {}", m_OpenIdProviderName);
+ ZEN_CONSOLE("Auth: Using OpenId provider: {}", m_OpenIdProviderName);
}
ClientSettings.AccessTokenProvider = httpclientauth::CreateFromOpenIdProvider(*Auth, m_OpenIdProviderName);
}
else if (std::string ResolvedAccessToken = GetEnvAccessToken(m_AccessTokenEnv); !ResolvedAccessToken.empty())
{
- if (Verbose)
+ if (!Quiet)
{
- ZEN_CONSOLE("Using environment variable '{}' as access token '{}'", m_AccessTokenEnv, ResolvedAccessToken);
+ ZEN_CONSOLE("Auth: Resolved environment variable '{}' to access token '{}'",
+ m_AccessTokenEnv,
+ HideSensitiveString(ResolvedAccessToken));
}
ClientSettings.AccessTokenProvider = httpclientauth::CreateFromStaticToken(ResolvedAccessToken);
}
else if (std::filesystem::path OidcTokenExePath = FindOidcTokenExePath(m_OidcTokenAuthExecutablePath); !OidcTokenExePath.empty())
{
- if (Verbose)
+ if (!Quiet)
{
- ZEN_CONSOLE("Running oidctoken exe from path '{}'", m_OidcTokenAuthExecutablePath);
+ ZEN_CONSOLE("Auth: Using oidctoken exe from path '{}'", OidcTokenExePath);
}
ClientSettings.AccessTokenProvider =
httpclientauth::CreateFromOidcTokenExecutable(OidcTokenExePath, HostUrl, Quiet, m_OidcTokenUnattended, Hidden);
@@ -291,9 +311,9 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops,
if (!ClientSettings.AccessTokenProvider)
{
CreateAuthMgr();
- if (Verbose)
+ if (!Quiet)
{
- ZEN_CONSOLE("Using default openid provider");
+ ZEN_CONSOLE("Auth: Using default Open ID provider");
}
ClientSettings.AccessTokenProvider = httpclientauth::CreateFromDefaultOpenIdProvider(*Auth);
}
diff --git a/src/zen/cmds/admin_cmd.h b/src/zen/cmds/admin_cmd.h
index 87ef8091b..83bcf8893 100644
--- a/src/zen/cmds/admin_cmd.h
+++ b/src/zen/cmds/admin_cmd.h
@@ -13,6 +13,9 @@ namespace zen {
class ScrubCommand : public StorageCommand
{
public:
+ static constexpr char Name[] = "scrub";
+ static constexpr char Description[] = "Scrub zen storage (verify data integrity)";
+
ScrubCommand();
~ScrubCommand();
@@ -20,7 +23,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"scrub", "Scrub zen storage"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
bool m_DryRun = false;
bool m_NoGc = false;
@@ -33,6 +36,9 @@ private:
class GcCommand : public StorageCommand
{
public:
+ static constexpr char Name[] = "gc";
+ static constexpr char Description[] = "Garbage collect zen storage";
+
GcCommand();
~GcCommand();
@@ -40,7 +46,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"gc", "Garbage collect zen storage"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
bool m_SmallObjects{false};
bool m_SkipCid{false};
@@ -62,6 +68,9 @@ private:
class GcStatusCommand : public StorageCommand
{
public:
+ static constexpr char Name[] = "gc-status";
+ static constexpr char Description[] = "Garbage collect zen storage status check";
+
GcStatusCommand();
~GcStatusCommand();
@@ -69,7 +78,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"gc-status", "Garbage collect zen storage status check"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
bool m_Details = false;
};
@@ -77,6 +86,9 @@ private:
class GcStopCommand : public StorageCommand
{
public:
+ static constexpr char Name[] = "gc-stop";
+ static constexpr char Description[] = "Request cancel of running garbage collection in zen storage";
+
GcStopCommand();
~GcStopCommand();
@@ -84,7 +96,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"gc-stop", "Request cancel of running garbage collection in zen storage"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
};
@@ -93,6 +105,9 @@ private:
class JobCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "jobs";
+ static constexpr char Description[] = "Show/cancel zen background jobs";
+
JobCommand();
~JobCommand();
@@ -100,7 +115,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"jobs", "Show/cancel zen background jobs"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::uint64_t m_JobId = 0;
bool m_Cancel = 0;
@@ -111,6 +126,9 @@ private:
class LoggingCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "logs";
+ static constexpr char Description[] = "Show/control zen logging";
+
LoggingCommand();
~LoggingCommand();
@@ -118,7 +136,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"logs", "Show/control zen logging"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_CacheWriteLog;
std::string m_CacheAccessLog;
@@ -133,6 +151,9 @@ private:
class FlushCommand : public StorageCommand
{
public:
+ static constexpr char Name[] = "flush";
+ static constexpr char Description[] = "Flush storage";
+
FlushCommand();
~FlushCommand();
@@ -140,7 +161,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"flush", "Flush zen storage"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
};
@@ -149,6 +170,9 @@ private:
class CopyStateCommand : public StorageCommand
{
public:
+ static constexpr char Name[] = "copy-state";
+ static constexpr char Description[] = "Copy zen server disk state";
+
CopyStateCommand();
~CopyStateCommand();
@@ -156,7 +180,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"copy-state", "Copy zen server disk state"};
+ cxxopts::Options m_Options{Name, Description};
std::filesystem::path m_DataPath;
std::filesystem::path m_TargetPath;
bool m_SkipLogs = false;
diff --git a/src/zen/cmds/bench_cmd.h b/src/zen/cmds/bench_cmd.h
index ed123be75..7fbf85340 100644
--- a/src/zen/cmds/bench_cmd.h
+++ b/src/zen/cmds/bench_cmd.h
@@ -9,6 +9,9 @@ namespace zen {
class BenchCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "bench";
+ static constexpr char Description[] = "Utility command for benchmarking";
+
BenchCommand();
~BenchCommand();
@@ -17,7 +20,7 @@ public:
virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; }
private:
- cxxopts::Options m_Options{"bench", "Benchmarking utility command"};
+ cxxopts::Options m_Options{Name, Description};
bool m_PurgeStandbyLists = false;
bool m_SingleProcess = false;
};
diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp
index f4edb65ab..b4b4df7c9 100644
--- a/src/zen/cmds/builds_cmd.cpp
+++ b/src/zen/cmds/builds_cmd.cpp
@@ -67,13 +67,11 @@ ZEN_THIRD_PARTY_INCLUDES_END
static const bool DoExtraContentVerify = false;
-#define ZEN_CLOUD_STORAGE "Cloud Storage"
-
namespace zen {
using namespace std::literals;
-namespace {
+namespace builds_impl {
static std::atomic<bool> AbortFlag = false;
static std::atomic<bool> PauseFlag = false;
@@ -270,10 +268,11 @@ namespace {
static bool IsQuiet = false;
static ProgressBar::Mode ProgressMode = ProgressBar::Mode::Pretty;
-#define ZEN_CONSOLE_VERBOSE(fmtstr, ...) \
- if (IsVerbose) \
- { \
- ZEN_CONSOLE_LOG(zen::logging::level::Info, fmtstr, ##__VA_ARGS__); \
+#undef ZEN_CONSOLE_VERBOSE
+#define ZEN_CONSOLE_VERBOSE(fmtstr, ...) \
+ if (IsVerbose) \
+ { \
+ ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__); \
}
const std::string DefaultAccessTokenEnvVariableName(
@@ -1467,9 +1466,16 @@ namespace {
ZEN_CONSOLE("Downloading build {}, parts:{} to '{}' ({})", BuildId, BuildPartString.ToView(), Path, NiceBytes(RawSize));
}
+ Stopwatch IndexTimer;
+
const ChunkedContentLookup LocalLookup = BuildChunkedContentLookup(LocalState.State.ChunkedContent);
const ChunkedContentLookup RemoteLookup = BuildChunkedContentLookup(RemoteContent);
+ if (!IsQuiet)
+ {
+ ZEN_OPERATION_LOG_INFO(Output, "Indexed local and remote content in {}", NiceTimeSpanMs(IndexTimer.GetElapsedTimeMs()));
+ }
+
ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Download, TaskSteps::StepCount);
BuildsOperationUpdateFolder Updater(
@@ -1588,7 +1594,7 @@ namespace {
}
}
}
- if (Storage.BuildCacheStorage)
+ if (Storage.CacheStorage)
{
if (SB.Size() > 0)
{
@@ -1643,9 +1649,9 @@ namespace {
}
if (Options.PrimeCacheOnly)
{
- if (Storage.BuildCacheStorage)
+ if (Storage.CacheStorage)
{
- Storage.BuildCacheStorage->Flush(5000, [](intptr_t Remaining) {
+ Storage.CacheStorage->Flush(5000, [](intptr_t Remaining) {
if (!IsQuiet)
{
if (Remaining == 0)
@@ -2002,12 +2008,13 @@ namespace {
ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Cleanup, TaskSteps::StepCount);
}
-} // namespace
+} // namespace builds_impl
//////////////////////////////////////////////////////////////////////////////////////////////////////
BuildsCommand::BuildsCommand()
{
+ using namespace builds_impl;
m_Options.add_options()("h,help", "Print help");
auto AddSystemOptions = [this](cxxopts::Options& Ops) {
@@ -2648,6 +2655,7 @@ BuildsCommand::~BuildsCommand() = default;
void
BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace builds_impl;
ZEN_UNUSED(GlobalOptions);
signal(SIGINT, SignalCallbackHandler);
@@ -2680,7 +2688,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
m_SystemRootDir = PickDefaultSystemRootDirectory();
}
- MakeSafeAbsolutePathÍnPlace(m_SystemRootDir);
+ MakeSafeAbsolutePathInPlace(m_SystemRootDir);
};
ParseSystemOptions();
@@ -2729,7 +2737,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
throw OptionParseException("'--host', '--url', '--override-host' or '--storage-path' is required", SubOption->help());
}
- MakeSafeAbsolutePathÍnPlace(m_StoragePath);
+ MakeSafeAbsolutePathInPlace(m_StoragePath);
};
auto ParseOutputOptions = [&]() {
@@ -2800,8 +2808,6 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
.Verbose = m_VerboseHttp,
.MaximumInMemoryDownloadSize = GetMaxMemoryBufferSize(DefaultMaxChunkBlockSize, m_BoostWorkerMemory)};
- std::unique_ptr<AuthMgr> Auth;
-
std::string StorageDescription;
std::string CacheDescription;
@@ -2820,44 +2826,47 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
BuildStorageResolveResult ResolveRes =
ResolveBuildStorage(*Output, ClientSettings, m_Host, m_OverrideHost, m_ZenCacheHost, ZenCacheResolveMode::All, m_Verbose);
- if (!ResolveRes.HostUrl.empty())
+ if (!ResolveRes.Cloud.Address.empty())
{
- ClientSettings.AssumeHttp2 = ResolveRes.HostAssumeHttp2;
+ ClientSettings.AssumeHttp2 = ResolveRes.Cloud.AssumeHttp2;
Result.BuildStorageHttp =
- std::make_unique<HttpClient>(ResolveRes.HostUrl, ClientSettings, []() { return AbortFlag.load(); });
+ std::make_unique<HttpClient>(ResolveRes.Cloud.Address, ClientSettings, []() { return AbortFlag.load(); });
- Result.BuildStorage = CreateJupiterBuildStorage(Log(),
+ Result.BuildStorage = CreateJupiterBuildStorage(Log(),
*Result.BuildStorageHttp,
StorageStats,
m_Namespace,
m_Bucket,
m_AllowRedirect,
TempPath / "storage");
- Result.StorageName = ResolveRes.HostName;
+ Result.BuildStorageHost = ResolveRes.Cloud;
+
+ uint64_t HostLatencyNs = ResolveRes.Cloud.LatencySec >= 0 ? uint64_t(ResolveRes.Cloud.LatencySec * 1000000000.0) : 0;
- StorageDescription = fmt::format("Cloud {}{}. SessionId: '{}'. Namespace '{}', Bucket '{}'",
- ResolveRes.HostName,
- (ResolveRes.HostUrl == ResolveRes.HostName) ? "" : fmt::format(" {}", ResolveRes.HostUrl),
- Result.BuildStorageHttp->GetSessionId(),
- m_Namespace,
- m_Bucket);
- ;
+ StorageDescription =
+ fmt::format("Cloud {}{}. SessionId: '{}'. Namespace '{}', Bucket '{}'. Latency: {}",
+ ResolveRes.Cloud.Name,
+ (ResolveRes.Cloud.Address == ResolveRes.Cloud.Name) ? "" : fmt::format(" {}", ResolveRes.Cloud.Address),
+ Result.BuildStorageHttp->GetSessionId(),
+ m_Namespace,
+ m_Bucket,
+ NiceLatencyNs(HostLatencyNs));
- if (!ResolveRes.CacheUrl.empty())
+ if (!ResolveRes.Cache.Address.empty())
{
Result.CacheHttp = std::make_unique<HttpClient>(
- ResolveRes.CacheUrl,
+ ResolveRes.Cache.Address,
HttpClientSettings{
.LogCategory = "httpcacheclient",
.ConnectTimeout = std::chrono::milliseconds{3000},
.Timeout = std::chrono::milliseconds{30000},
- .AssumeHttp2 = ResolveRes.CacheAssumeHttp2,
+ .AssumeHttp2 = ResolveRes.Cache.AssumeHttp2,
.AllowResume = true,
.RetryCount = 0,
.Verbose = m_VerboseHttp,
.MaximumInMemoryDownloadSize = GetMaxMemoryBufferSize(DefaultMaxChunkBlockSize, m_BoostWorkerMemory)},
[]() { return AbortFlag.load(); });
- Result.BuildCacheStorage =
+ Result.CacheStorage =
CreateZenBuildStorageCache(*Result.CacheHttp,
StorageCacheStats,
m_Namespace,
@@ -2865,14 +2874,17 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
TempPath / "zencache",
BoostCacheBackgroundWorkerPool ? GetSmallWorkerPool(EWorkloadType::Background)
: GetTinyWorkerPool(EWorkloadType::Background));
- Result.CacheName = ResolveRes.CacheName;
+ Result.CacheHost = ResolveRes.Cache;
+
+ uint64_t CacheLatencyNs = ResolveRes.Cache.LatencySec >= 0 ? uint64_t(ResolveRes.Cache.LatencySec * 1000000000.0) : 0;
CacheDescription =
- fmt::format("Zen {}{}. SessionId: '{}'",
- ResolveRes.CacheName,
- (ResolveRes.CacheUrl == ResolveRes.CacheName) ? "" : fmt::format(" {}", ResolveRes.CacheUrl),
- Result.CacheHttp->GetSessionId());
- ;
+ fmt::format("Zen {}{}. SessionId: '{}'. Latency: {}",
+ ResolveRes.Cache.Name,
+ (ResolveRes.Cache.Address == ResolveRes.Cache.Name) ? "" : fmt::format(" {}", ResolveRes.Cache.Address),
+ Result.CacheHttp->GetSessionId(),
+ NiceLatencyNs(CacheLatencyNs));
+
if (!m_Namespace.empty())
{
CacheDescription += fmt::format(". Namespace '{}'", m_Namespace);
@@ -2888,41 +2900,56 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
StorageDescription = fmt::format("folder {}", m_StoragePath);
Result.BuildStorage = CreateFileBuildStorage(m_StoragePath, StorageStats, false, DefaultLatency, DefaultDelayPerKBSec);
- Result.StorageName = fmt::format("Disk {}", m_StoragePath.stem());
+
+ Result.BuildStorageHost = BuildStorageResolveResult::Host{.Address = m_StoragePath.generic_string(),
+ .Name = "Disk",
+ .LatencySec = 1.0 / 100000, // 1 us
+ .Caps = {.MaxRangeCountPerRequest = 2048u}};
if (!m_ZenCacheHost.empty())
{
- Result.CacheHttp = std::make_unique<HttpClient>(
- m_ZenCacheHost,
- HttpClientSettings{
- .LogCategory = "httpcacheclient",
- .ConnectTimeout = std::chrono::milliseconds{3000},
- .Timeout = std::chrono::milliseconds{30000},
- .AssumeHttp2 = m_AssumeHttp2,
- .AllowResume = true,
- .RetryCount = 0,
- .Verbose = m_VerboseHttp,
- .MaximumInMemoryDownloadSize = GetMaxMemoryBufferSize(DefaultMaxChunkBlockSize, m_BoostWorkerMemory)},
- []() { return AbortFlag.load(); });
- Result.BuildCacheStorage =
- CreateZenBuildStorageCache(*Result.CacheHttp,
- StorageCacheStats,
- m_Namespace,
- m_Bucket,
- TempPath / "zencache",
- BoostCacheBackgroundWorkerPool ? GetSmallWorkerPool(EWorkloadType::Background)
- : GetTinyWorkerPool(EWorkloadType::Background));
- Result.CacheName = m_ZenCacheHost;
-
- CacheDescription = fmt::format("Zen {}{}. SessionId: '{}'", Result.CacheName, "", Result.CacheHttp->GetSessionId());
- ;
- if (!m_Namespace.empty())
- {
- CacheDescription += fmt::format(". Namespace '{}'", m_Namespace);
- }
- if (!m_Bucket.empty())
+ ZenCacheEndpointTestResult TestResult = TestZenCacheEndpoint(m_ZenCacheHost, m_AssumeHttp2, m_VerboseHttp);
+
+ if (TestResult.Success)
{
- CacheDescription += fmt::format(" Bucket '{}'", m_Bucket);
+ Result.CacheHttp = std::make_unique<HttpClient>(
+ m_ZenCacheHost,
+ HttpClientSettings{
+ .LogCategory = "httpcacheclient",
+ .ConnectTimeout = std::chrono::milliseconds{3000},
+ .Timeout = std::chrono::milliseconds{30000},
+ .AssumeHttp2 = m_AssumeHttp2,
+ .AllowResume = true,
+ .RetryCount = 0,
+ .Verbose = m_VerboseHttp,
+ .MaximumInMemoryDownloadSize = GetMaxMemoryBufferSize(DefaultMaxChunkBlockSize, m_BoostWorkerMemory)},
+ []() { return AbortFlag.load(); });
+
+ Result.CacheStorage =
+ CreateZenBuildStorageCache(*Result.CacheHttp,
+ StorageCacheStats,
+ m_Namespace,
+ m_Bucket,
+ TempPath / "zencache",
+ BoostCacheBackgroundWorkerPool ? GetSmallWorkerPool(EWorkloadType::Background)
+ : GetTinyWorkerPool(EWorkloadType::Background));
+ Result.CacheHost =
+ BuildStorageResolveResult::Host{.Address = m_ZenCacheHost,
+ .Name = m_ZenCacheHost,
+ .AssumeHttp2 = m_AssumeHttp2,
+ .LatencySec = TestResult.LatencySeconds,
+ .Caps = {.MaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest}};
+
+ CacheDescription = fmt::format("Zen {}. SessionId: '{}'", Result.CacheHost.Name, Result.CacheHttp->GetSessionId());
+
+ if (!m_Namespace.empty())
+ {
+ CacheDescription += fmt::format(". Namespace '{}'", m_Namespace);
+ }
+ if (!m_Bucket.empty())
+ {
+ CacheDescription += fmt::format(" Bucket '{}'", m_Bucket);
+ }
}
}
}
@@ -2934,7 +2961,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
if (!IsQuiet)
{
ZEN_CONSOLE("Remote: {}", StorageDescription);
- if (!Result.CacheName.empty())
+ if (!Result.CacheHost.Name.empty())
{
ZEN_CONSOLE("Cache : {}", CacheDescription);
}
@@ -2947,7 +2974,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
throw OptionParseException("'--local-path' is required", SubOption->help());
}
- MakeSafeAbsolutePathÍnPlace(m_Path);
+ MakeSafeAbsolutePathInPlace(m_Path);
};
auto ParseFileFilters = [&](std::vector<std::string>& OutIncludeWildcards, std::vector<std::string>& OutExcludeWildcards) {
@@ -3004,7 +3031,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
throw OptionParseException("'--compare-path' is required", SubOption->help());
}
- MakeSafeAbsolutePathÍnPlace(m_DiffPath);
+ MakeSafeAbsolutePathInPlace(m_DiffPath);
};
auto ParseBlobHash = [&]() -> IoHash {
@@ -3016,7 +3043,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
if (m_BlobHash.length() != IoHash::StringLength)
{
throw OptionParseException(
- fmt::format("'--blob-hash' ('{}') is malfomed, it must be {} characters long", m_BlobHash, IoHash::StringLength),
+ fmt::format("'--blob-hash' ('{}') is malformed, it must be {} characters long", m_BlobHash, IoHash::StringLength),
SubOption->help());
}
@@ -3033,7 +3060,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
if (m_BuildId.length() != Oid::StringLength)
{
throw OptionParseException(
- fmt::format("'--build-id' ('{}') is malfomed, it must be {} characters long", m_BuildId, Oid::StringLength),
+ fmt::format("'--build-id' ('{}') is malformed, it must be {} characters long", m_BuildId, Oid::StringLength),
SubOption->help());
}
else if (Oid BuildId = Oid::FromHexString(m_BuildId); BuildId == Oid::Zero)
@@ -3105,7 +3132,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
if (!m_BuildMetadataPath.empty())
{
- MakeSafeAbsolutePathÍnPlace(m_BuildMetadataPath);
+ MakeSafeAbsolutePathInPlace(m_BuildMetadataPath);
IoBuffer MetaDataJson = ReadFile(m_BuildMetadataPath).Flatten();
std::string_view Json(reinterpret_cast<const char*>(MetaDataJson.GetData()), MetaDataJson.GetSize());
std::string JsonError;
@@ -3202,8 +3229,8 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
if (SubOption == &m_ListOptions)
{
- MakeSafeAbsolutePathÍnPlace(m_ListQueryPath);
- MakeSafeAbsolutePathÍnPlace(m_ListResultPath);
+ MakeSafeAbsolutePathInPlace(m_ListQueryPath);
+ MakeSafeAbsolutePathInPlace(m_ListResultPath);
if (!m_ListResultPath.empty())
{
@@ -3255,7 +3282,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName;
}
- MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath);
+ MakeSafeAbsolutePathInPlace(m_ZenFolderPath);
CreateDirectories(m_ZenFolderPath);
auto _ = MakeGuard([this]() { CleanAndRemoveDirectory(GetSmallWorkerPool(EWorkloadType::Burst), m_ZenFolderPath); });
@@ -3294,7 +3321,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
if (SubOption == &m_ListBlocksOptions)
{
- MakeSafeAbsolutePathÍnPlace(m_ListResultPath);
+ MakeSafeAbsolutePathInPlace(m_ListResultPath);
if (!m_ListResultPath.empty())
{
@@ -3316,7 +3343,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName;
}
- MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath);
+ MakeSafeAbsolutePathInPlace(m_ZenFolderPath);
CreateDirectories(m_ZenFolderPath);
auto _ = MakeGuard([this]() { CleanAndRemoveDirectory(GetSmallWorkerPool(EWorkloadType::Burst), m_ZenFolderPath); });
@@ -3387,8 +3414,8 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName;
}
- MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath);
- MakeSafeAbsolutePathÍnPlace(m_ChunkingCachePath);
+ MakeSafeAbsolutePathInPlace(m_ZenFolderPath);
+ MakeSafeAbsolutePathInPlace(m_ChunkingCachePath);
CreateDirectories(m_ZenFolderPath);
auto _ = MakeGuard([this, &Workers]() { CleanAndRemoveDirectory(Workers.GetIOWorkerPool(), m_ZenFolderPath); });
@@ -3475,7 +3502,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
"Requests: {}\n"
"Avg Request Time: {}\n"
"Avg I/O Time: {}",
- Storage.StorageName,
+ Storage.BuildStorageHost.Name,
NiceBytes(StorageStats.TotalBytesRead.load()),
NiceBytes(StorageStats.TotalBytesWritten.load()),
StorageStats.TotalRequestCount.load(),
@@ -3532,7 +3559,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
m_ZenFolderPath = m_Path / ZenFolderName;
}
- MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath);
+ MakeSafeAbsolutePathInPlace(m_ZenFolderPath);
BuildStorageBase::Statistics StorageStats;
BuildStorageCache::Statistics StorageCacheStats;
@@ -3632,7 +3659,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
m_ZenFolderPath = m_Path / ZenFolderName;
}
- MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath);
+ MakeSafeAbsolutePathInPlace(m_ZenFolderPath);
BuildStorageBase::Statistics StorageStats;
BuildStorageCache::Statistics StorageCacheStats;
@@ -3652,7 +3679,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
std::unique_ptr<CbObjectWriter> StructuredOutput;
if (!m_LsResultPath.empty())
{
- MakeSafeAbsolutePathÍnPlace(m_LsResultPath);
+ MakeSafeAbsolutePathInPlace(m_LsResultPath);
StructuredOutput = std::make_unique<CbObjectWriter>();
}
@@ -3696,7 +3723,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
ParsePath();
ParseDiffPath();
- MakeSafeAbsolutePathÍnPlace(m_ChunkingCachePath);
+ MakeSafeAbsolutePathInPlace(m_ChunkingCachePath);
std::vector<std::string> ExcludeFolders = DefaultExcludeFolders;
std::vector<std::string> ExcludeExtensions = DefaultExcludeExtensions;
@@ -3745,7 +3772,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName;
}
- MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath);
+ MakeSafeAbsolutePathInPlace(m_ZenFolderPath);
CreateDirectories(m_ZenFolderPath);
auto _ = MakeGuard([this, &Workers]() { CleanAndRemoveDirectory(Workers.GetIOWorkerPool(), m_ZenFolderPath); });
@@ -3796,12 +3823,12 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
if (!IsQuiet)
{
- if (Storage.BuildCacheStorage)
+ if (Storage.CacheStorage)
{
- ZEN_CONSOLE("Uploaded {} ({}) blobs",
+ ZEN_CONSOLE("Uploaded {} ({}) blobs to {}",
StorageCacheStats.PutBlobCount.load(),
NiceBytes(StorageCacheStats.PutBlobByteCount),
- Storage.CacheName);
+ Storage.CacheHost.Name);
}
}
@@ -3828,7 +3855,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName;
}
- MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath);
+ MakeSafeAbsolutePathInPlace(m_ZenFolderPath);
CreateDirectories(m_ZenFolderPath);
auto _ = MakeGuard([this, &Workers]() { CleanAndRemoveDirectory(Workers.GetIOWorkerPool(), m_ZenFolderPath); });
@@ -3883,7 +3910,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName;
}
- MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath);
+ MakeSafeAbsolutePathInPlace(m_ZenFolderPath);
CreateDirectories(m_ZenFolderPath);
auto _ = MakeGuard([this, &Workers]() { CleanAndRemoveDirectory(Workers.GetIOWorkerPool(), m_ZenFolderPath); });
@@ -3933,7 +3960,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
m_ZenFolderPath = m_Path / ZenFolderName;
}
- MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath);
+ MakeSafeAbsolutePathInPlace(m_ZenFolderPath);
EPartialBlockRequestMode PartialBlockRequestMode = ParseAllowPartialBlockRequests();
@@ -4083,8 +4110,8 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
m_ZenFolderPath = m_Path / ZenFolderName;
}
- MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath);
- MakeSafeAbsolutePathÍnPlace(m_ChunkingCachePath);
+ MakeSafeAbsolutePathInPlace(m_ZenFolderPath);
+ MakeSafeAbsolutePathInPlace(m_ChunkingCachePath);
StorageInstance Storage = CreateBuildStorage(StorageStats,
StorageCacheStats,
diff --git a/src/zen/cmds/builds_cmd.h b/src/zen/cmds/builds_cmd.h
index f5c44ab55..5c80beed5 100644
--- a/src/zen/cmds/builds_cmd.h
+++ b/src/zen/cmds/builds_cmd.h
@@ -71,7 +71,7 @@ private:
bool m_AppendNewContent = false;
uint8_t m_BlockReuseMinPercentLimit = 85;
bool m_AllowMultiparts = true;
- std::string m_AllowPartialBlockRequests = "mixed";
+ std::string m_AllowPartialBlockRequests = "true";
AuthCommandLineOptions m_AuthOptions;
diff --git a/src/zen/cmds/cache_cmd.h b/src/zen/cmds/cache_cmd.h
index 4dc05bbdc..4f5b90f4d 100644
--- a/src/zen/cmds/cache_cmd.h
+++ b/src/zen/cmds/cache_cmd.h
@@ -9,6 +9,9 @@ namespace zen {
class DropCommand : public CacheStoreCommand
{
public:
+ static constexpr char Name[] = "drop";
+ static constexpr char Description[] = "Drop cache namespace or bucket";
+
DropCommand();
~DropCommand();
@@ -16,7 +19,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"drop", "Drop cache namespace or bucket"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_NamespaceName;
std::string m_BucketName;
@@ -25,13 +28,16 @@ private:
class CacheInfoCommand : public CacheStoreCommand
{
public:
+ static constexpr char Name[] = "cache-info";
+ static constexpr char Description[] = "Info on cache, namespace or bucket";
+
CacheInfoCommand();
~CacheInfoCommand();
virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"cache-info", "Info on cache, namespace or bucket"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_NamespaceName;
std::string m_SizeInfoBucketNames;
@@ -42,26 +48,32 @@ private:
class CacheStatsCommand : public CacheStoreCommand
{
public:
+ static constexpr char Name[] = "cache-stats";
+ static constexpr char Description[] = "Stats on cache";
+
CacheStatsCommand();
~CacheStatsCommand();
virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"cache-stats", "Stats info on cache"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
};
class CacheDetailsCommand : public CacheStoreCommand
{
public:
+ static constexpr char Name[] = "cache-details";
+ static constexpr char Description[] = "Details on cache";
+
CacheDetailsCommand();
~CacheDetailsCommand();
virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"cache-details", "Detailed info on cache"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
bool m_CSV = false;
bool m_Details = false;
diff --git a/src/zen/cmds/copy_cmd.h b/src/zen/cmds/copy_cmd.h
index e1a5dcb82..757a8e691 100644
--- a/src/zen/cmds/copy_cmd.h
+++ b/src/zen/cmds/copy_cmd.h
@@ -11,6 +11,9 @@ namespace zen {
class CopyCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "copy";
+ static constexpr char Description[] = "Copy file(s)";
+
CopyCommand();
~CopyCommand();
@@ -19,7 +22,7 @@ public:
virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; }
private:
- cxxopts::Options m_Options{"copy", "Copy files efficiently"};
+ cxxopts::Options m_Options{Name, Description};
std::filesystem::path m_CopySource;
std::filesystem::path m_CopyTarget;
bool m_NoClone = false;
diff --git a/src/zen/cmds/dedup_cmd.h b/src/zen/cmds/dedup_cmd.h
index 5b8387dd2..835b35e92 100644
--- a/src/zen/cmds/dedup_cmd.h
+++ b/src/zen/cmds/dedup_cmd.h
@@ -11,6 +11,9 @@ namespace zen {
class DedupCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "dedup";
+ static constexpr char Description[] = "Dedup files";
+
DedupCommand();
~DedupCommand();
@@ -19,7 +22,7 @@ public:
virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; }
private:
- cxxopts::Options m_Options{"dedup", "Deduplicate files"};
+ cxxopts::Options m_Options{Name, Description};
std::vector<std::string> m_Positional;
std::filesystem::path m_DedupSource;
std::filesystem::path m_DedupTarget;
diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp
new file mode 100644
index 000000000..42c7119e7
--- /dev/null
+++ b/src/zen/cmds/exec_cmd.cpp
@@ -0,0 +1,1374 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "exec_cmd.h"
+
+#include <zencompute/computeservice.h>
+#include <zencompute/recordingreader.h>
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryfile.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalue.h>
+#include <zencore/compress.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/session.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zencore/system.h>
+#include <zencore/timer.h>
+#include <zenhttp/httpclient.h>
+#include <zenhttp/packageformat.h>
+
+#include <EASTL/hash_map.h>
+#include <EASTL/hash_set.h>
+#include <EASTL/map.h>
+
+using namespace std::literals;
+
+namespace eastl {
+
+template<>
+struct hash<zen::IoHash> : public zen::IoHash::Hasher
+{
+};
+
+} // namespace eastl
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+namespace zen {
+
+ExecCommand::ExecCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName), "<hosturl>");
+ m_Options.add_option("", "", "log", "Action log directory", cxxopts::value(m_RecordingLogPath), "<path>");
+ m_Options.add_option("", "p", "path", "Recording path (directory or .actionlog file)", cxxopts::value(m_RecordingPath), "<path>");
+ m_Options.add_option("", "", "offset", "Recording replay start offset", cxxopts::value(m_Offset), "<offset>");
+ m_Options.add_option("", "", "stride", "Recording replay stride", cxxopts::value(m_Stride), "<stride>");
+ m_Options.add_option("", "", "limit", "Recording replay limit", cxxopts::value(m_Limit), "<limit>");
+ m_Options.add_option("", "", "beacon", "Beacon path", cxxopts::value(m_BeaconPath), "<path>");
+ m_Options.add_option("", "", "orch", "Orchestrator URL for worker discovery", cxxopts::value(m_OrchestratorUrl), "<url>");
+ m_Options.add_option("",
+ "",
+ "mode",
+ "Select execution mode (http,inproc,dump,direct,beacon,buildlog)",
+ cxxopts::value(m_Mode)->default_value("http"),
+ "<string>");
+ m_Options
+ .add_option("", "", "dump-actions", "Dump each action to console as it is dispatched", cxxopts::value(m_DumpActions), "<bool>");
+ m_Options.add_option("", "o", "output", "Save action results to directory", cxxopts::value(m_OutputPath), "<path>");
+ m_Options.add_option("", "", "binary", "Write output as binary packages instead of YAML", cxxopts::value(m_Binary), "<bool>");
+ m_Options.add_option("", "", "quiet", "Quiet mode (less logging)", cxxopts::value(m_Quiet), "<bool>");
+ m_Options.parse_positional("mode");
+}
+
+ExecCommand::~ExecCommand()
+{
+}
+
+void
+ExecCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ // Configure
+
+ if (!ParseOptions(argc, argv))
+ {
+ return;
+ }
+
+ m_HostName = ResolveTargetHostSpec(m_HostName);
+
+ if (m_RecordingPath.empty())
+ {
+ throw OptionParseException("replay path is required!", m_Options.help());
+ }
+
+ m_VerboseLogging = GlobalOptions.IsVerbose;
+ m_QuietLogging = m_Quiet && !m_VerboseLogging;
+
+ enum ExecMode
+ {
+ kHttp,
+ kDirect,
+ kInproc,
+ kDump,
+ kBeacon,
+ kBuildLog
+ } Mode;
+
+ if (m_Mode == "http"sv)
+ {
+ Mode = kHttp;
+ }
+ else if (m_Mode == "direct"sv)
+ {
+ Mode = kDirect;
+ }
+ else if (m_Mode == "inproc"sv)
+ {
+ Mode = kInproc;
+ }
+ else if (m_Mode == "dump"sv)
+ {
+ Mode = kDump;
+ }
+ else if (m_Mode == "beacon"sv)
+ {
+ Mode = kBeacon;
+ }
+ else if (m_Mode == "buildlog"sv)
+ {
+ Mode = kBuildLog;
+ }
+ else
+ {
+ throw OptionParseException("invalid mode specified!", m_Options.help());
+ }
+
+ // Gather information from recording path
+
+ std::unique_ptr<zen::compute::RecordingReader> Reader;
+ std::unique_ptr<zen::compute::UeRecordingReader> UeReader;
+
+ std::filesystem::path RecordingPath{m_RecordingPath};
+
+ if (!std::filesystem::is_directory(RecordingPath))
+ {
+ throw OptionParseException("replay path should be a directory path!", m_Options.help());
+ }
+ else
+ {
+ if (std::filesystem::is_directory(RecordingPath / "cid"))
+ {
+ Reader = std::make_unique<zen::compute::RecordingReader>(RecordingPath);
+ m_WorkerMap = Reader->ReadWorkers();
+ m_ChunkResolver = Reader.get();
+ m_RecordingReader = Reader.get();
+ }
+ else
+ {
+ UeReader = std::make_unique<zen::compute::UeRecordingReader>(RecordingPath);
+ m_WorkerMap = UeReader->ReadWorkers();
+ m_ChunkResolver = UeReader.get();
+ m_RecordingReader = UeReader.get();
+ }
+ }
+
+ ZEN_CONSOLE("found {} workers, {} action items", m_WorkerMap.size(), m_RecordingReader->GetActionCount());
+
+ for (auto& Kv : m_WorkerMap)
+ {
+ CbObject WorkerDesc = Kv.second.GetObject();
+ const IoHash& WorkerId = Kv.first;
+
+ RegisterWorkerFunctionsFromDescription(WorkerDesc, WorkerId);
+
+ if (m_VerboseLogging)
+ {
+ zen::ExtendableStringBuilder<1024> ObjStr;
+# if 0
+ zen::CompactBinaryToJson(WorkerDesc, ObjStr);
+ ZEN_CONSOLE("worker {}: {}", WorkerId, ObjStr);
+# else
+ zen::CompactBinaryToYaml(WorkerDesc, ObjStr);
+ ZEN_CONSOLE("worker {}:\n{}", WorkerId, ObjStr);
+# endif
+ }
+ }
+
+ if (m_VerboseLogging)
+ {
+ EmitFunctionList(m_FunctionList);
+ }
+
+ // Iterate over work items and dispatch or log them
+
+ int ReturnValue = 0;
+
+ Stopwatch ExecTimer;
+
+ switch (Mode)
+ {
+ case kHttp:
+ // Forward requests to HTTP function service
+ ReturnValue = HttpExecute();
+ break;
+
+ case kDirect:
+ // Not currently supported
+ ReturnValue = LocalMessagingExecute();
+ break;
+
+ case kInproc:
+ // Handle execution in-core (by spawning child processes)
+ ReturnValue = InProcessExecute();
+ break;
+
+ case kDump:
+ // Dump high level information about actions to console
+ ReturnValue = DumpWorkItems();
+ break;
+
+ case kBeacon:
+ ReturnValue = BeaconExecute();
+ break;
+
+ case kBuildLog:
+ ReturnValue = BuildActionsLog();
+ break;
+
+ default:
+ ZEN_ERROR("Unknown operating mode! No work submitted");
+
+ ReturnValue = 1;
+ }
+
+ ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs()));
+
+ if (!ReturnValue)
+ {
+ ZEN_CONSOLE("all work items completed successfully");
+ }
+ else
+ {
+ ZEN_CONSOLE("some work items failed (code {})", ReturnValue);
+ }
+}
+
+int
+ExecCommand::InProcessExecute()
+{
+ ZEN_ASSERT(m_ChunkResolver);
+ ChunkResolver& Resolver = *m_ChunkResolver;
+
+ zen::compute::ComputeServiceSession ComputeSession(Resolver);
+
+ std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");
+ ComputeSession.AddLocalRunner(Resolver, TempPath);
+
+ return ExecUsingSession(ComputeSession);
+}
+
int
ExecCommand::ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession)
{
    // Shared driver used by every execution mode that actually submits work:
    //   1. registers with the optional orchestrator (best-effort),
    //   2. creates a compute queue for this session,
    //   3. submits each recorded action (honoring offset/stride/limit),
    //   4. drains completed results, optionally saving actions/results to
    //      m_OutputPath as YAML (or .pkg when m_Binary), and
    //   5. writes summary.txt and summary.html reports.
    // Returns 0 on success, 1 if any work item failed to submit.

    // Thread-safe set of LSNs that have been submitted but not yet completed
    struct JobTracker
    {
    public:
        inline void Insert(int LsnField)
        {
            RwLock::ExclusiveLockScope _(Lock);
            PendingJobs.insert(LsnField);
        }

        inline bool IsEmpty() const
        {
            RwLock::SharedLockScope _(Lock);
            return PendingJobs.empty();
        }

        inline void Remove(int CompleteLsn)
        {
            RwLock::ExclusiveLockScope _(Lock);
            PendingJobs.erase(CompleteLsn);
        }

        inline size_t GetSize() const
        {
            RwLock::SharedLockScope _(Lock);
            return PendingJobs.size();
        }

    private:
        mutable RwLock Lock;
        std::unordered_set<int> PendingJobs;
    };

    JobTracker PendingJobs;

    // Per-action bookkeeping, keyed by LSN, used to build the summary reports
    struct ActionSummaryEntry
    {
        int32_t Lsn = 0;
        int RecordingIndex = 0;
        IoHash ActionId;
        std::string FunctionName;
        int InputAttachments = 0;
        uint64_t InputBytes = 0;
        int OutputAttachments = 0;
        uint64_t OutputBytes = 0;
        float WallSeconds = 0.0f;
        float CpuSeconds = 0.0f;
        uint64_t SubmittedTicks = 0;
        uint64_t StartedTicks = 0;
        std::string ExecutionLocation;
    };

    std::mutex SummaryLock;
    std::unordered_map<int32_t, ActionSummaryEntry> SummaryEntries;

    ComputeSession.WaitUntilReady();

    // Register as a client with the orchestrator (best-effort)

    std::string OrchestratorClientId;

    if (!m_OrchestratorUrl.empty())
    {
        try
        {
            HttpClient OrchestratorClient(m_OrchestratorUrl);

            CbObjectWriter Ann;
            Ann << "session_id"sv << GetSessionId();
            Ann << "hostname"sv << std::string_view(GetMachineName());

            CbObjectWriter Meta;
            Meta << "source"sv
                 << "zen-exec"sv;
            Ann << "metadata"sv << Meta.Save();

            auto Resp = OrchestratorClient.Post("/orch/clients", Ann.Save());
            if (Resp.IsSuccess())
            {
                OrchestratorClientId = std::string(Resp.AsObject()["id"].AsString());
                ZEN_CONSOLE_INFO("registered with orchestrator as {}", OrchestratorClientId);
            }
            else
            {
                ZEN_WARN("failed to register with orchestrator (status {})", static_cast<int>(Resp.StatusCode));
            }
        }
        catch (const std::exception& Ex)
        {
            // Registration is best-effort; execution continues without it
            ZEN_WARN("failed to register with orchestrator: {}", Ex.what());
        }
    }

    Stopwatch OrchestratorHeartbeatTimer;

    // Throttled (30s) keep-alive; no-op unless registration above succeeded
    auto SendOrchestratorHeartbeat = [&] {
        if (OrchestratorClientId.empty() || OrchestratorHeartbeatTimer.GetElapsedTimeMs() < 30'000)
        {
            return;
        }
        OrchestratorHeartbeatTimer.Reset();
        try
        {
            HttpClient OrchestratorClient(m_OrchestratorUrl);
            std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/update", OrchestratorClientId));
        }
        catch (...)
        {
            // Best-effort: a missed heartbeat is not fatal
        }
    };

    // Notify the orchestrator on scope exit that this client is done (best-effort)
    auto ClientCleanup = MakeGuard([&] {
        if (!OrchestratorClientId.empty())
        {
            try
            {
                HttpClient OrchestratorClient(m_OrchestratorUrl);
                std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/complete", OrchestratorClientId));
            }
            catch (...)
            {
            }
        }
    });

    // Create a queue to group all actions from this exec session

    CbObjectWriter Metadata;
    Metadata << "source"sv
             << "zen-exec"sv;

    auto QueueResult = ComputeSession.CreateQueue("zen-exec", Metadata.Save());
    const int QueueId = QueueResult.QueueId;
    if (!QueueId)
    {
        ZEN_ERROR("failed to create compute queue");
        return 1;
    }

    auto QueueCleanup = MakeGuard([&] { ComputeSession.DeleteQueue(QueueId); });

    if (!m_OutputPath.empty())
    {
        zen::CreateDirectories(m_OutputPath);
    }

    // Guards against re-entrant draining when invoked from multiple submission threads
    std::atomic<int> IsDraining{0};

    auto DrainCompletedJobs = [&] {
        if (IsDraining.exchange(1))
        {
            // Another thread is already draining; skip
            return;
        }

        auto _ = MakeGuard([&] { IsDraining.store(0, std::memory_order_release); });

        CbObjectWriter Cbo;
        ComputeSession.GetQueueCompleted(QueueId, Cbo);

        if (CbObject Completed = Cbo.Save())
        {
            for (auto& It : Completed["completed"sv])
            {
                int32_t CompleteLsn = It.AsInt32();

                CbPackage ResultPackage;
                HttpResponseCode Response = ComputeSession.GetActionResult(CompleteLsn, /* out */ ResultPackage);

                // NOTE(review): when GetActionResult returns a non-OK status the
                // LSN is never removed from PendingJobs, so the completion wait
                // loop below can spin forever -- confirm whether failed fetches
                // should also clear the pending entry.
                if (Response == HttpResponseCode::OK)
                {
                    if (!m_OutputPath.empty() && ResultPackage)
                    {
                        int OutputAttachments = 0;
                        uint64_t OutputBytes = 0;

                        if (!m_Binary)
                        {
                            // Write the root object as YAML
                            ExtendableStringBuilder<4096> YamlStr;
                            CompactBinaryToYaml(ResultPackage.GetObject(), YamlStr);

                            std::string_view Yaml = YamlStr;
                            zen::WriteFile(m_OutputPath / fmt::format("{}.result.yaml", CompleteLsn),
                                           IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size()));

                            // Write decompressed attachments
                            auto Attachments = ResultPackage.GetAttachments();

                            if (!Attachments.empty())
                            {
                                std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.result.attachments", CompleteLsn);
                                zen::CreateDirectories(AttDir);

                                for (const CbAttachment& Att : Attachments)
                                {
                                    ++OutputAttachments;

                                    IoHash AttHash = Att.GetHash();

                                    if (Att.IsCompressedBinary())
                                    {
                                        SharedBuffer Decompressed = Att.AsCompressedBinary().Decompress();
                                        OutputBytes += Decompressed.GetSize();
                                        zen::WriteFile(AttDir / AttHash.ToHexString(),
                                                       IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize()));
                                    }
                                    else
                                    {
                                        SharedBuffer Binary = Att.AsBinary();
                                        OutputBytes += Binary.GetSize();
                                        zen::WriteFile(AttDir / AttHash.ToHexString(),
                                                       IoBuffer(IoBuffer::Clone, Binary.GetData(), Binary.GetSize()));
                                    }
                                }
                            }

                            if (!m_QuietLogging)
                            {
                                ZEN_CONSOLE("saved result: {}/{}.result.yaml ({} attachments)",
                                            m_OutputPath.string(),
                                            CompleteLsn,
                                            OutputAttachments);
                            }
                        }
                        else
                        {
                            // Binary mode: persist the whole result package as-is
                            CompositeBuffer Serialized = FormatPackageMessageBuffer(ResultPackage);
                            zen::WriteFile(m_OutputPath / fmt::format("{}.result.pkg", CompleteLsn), std::move(Serialized));

                            for (const CbAttachment& Att : ResultPackage.GetAttachments())
                            {
                                ++OutputAttachments;
                                OutputBytes += Att.AsBinary().GetSize();
                            }

                            if (!m_QuietLogging)
                            {
                                ZEN_CONSOLE("saved result: {}/{}.result.pkg", m_OutputPath.string(), CompleteLsn);
                            }
                        }

                        // Fold output stats into the summary entry for this LSN
                        std::lock_guard Lock(SummaryLock);
                        if (auto It2 = SummaryEntries.find(CompleteLsn); It2 != SummaryEntries.end())
                        {
                            It2->second.OutputAttachments = OutputAttachments;
                            It2->second.OutputBytes = OutputBytes;
                        }
                    }

                    PendingJobs.Remove(CompleteLsn);

                    ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, PendingJobs.GetSize());
                }
            }
        }
    };

    // Describe workers

    ZEN_CONSOLE("describing {} workers", m_WorkerMap.size());

    // NOTE(review): iterating by value copies each (IoHash, CbPackage) pair and
    // then copies the package again -- `const auto& [Id, Desc]` would avoid both.
    for (auto Kv : m_WorkerMap)
    {
        CbPackage WorkerDesc = Kv.second;

        ComputeSession.RegisterWorker(WorkerDesc);
    }

    // Then submit work items

    int FailedWorkCounter = 0;
    size_t RemainingWorkItems = m_RecordingReader->GetActionCount();
    int SubmittedWorkItems = 0;

    ZEN_CONSOLE("submitting {} work items", RemainingWorkItems);

    int OffsetCounter = m_Offset;
    int StrideCounter = m_Stride;

    // Decides whether the current action is scheduled, applying --limit,
    // --offset and --stride in that order. Stateful: mutates the counters.
    auto ShouldSchedule = [&]() -> bool {
        if (m_Limit && SubmittedWorkItems >= m_Limit)
        {
            // Limit reached, ignore

            return false;
        }

        if (OffsetCounter && OffsetCounter--)
        {
            // Still in offset, ignore

            return false;
        }

        if (--StrideCounter == 0)
        {
            StrideCounter = m_Stride;

            return true;
        }

        return false;
    };

    int TargetParallelism = 8;

    // NOTE(review): StrideCounter starts at m_Stride (default 1), so this
    // condition is true for any stride > 0 and parallelism is effectively
    // always forced to 1. Possibly intended `m_Offset || m_Stride > 1 || m_Limit`
    // -- confirm.
    if (OffsetCounter || StrideCounter || m_Limit)
    {
        TargetParallelism = 1;
    }

    std::atomic<int> RecordingIndex{0};

    m_RecordingReader->IterateActions(
        [&](CbObject ActionObject, const IoHash& ActionId) {
            // Enqueue job

            const int CurrentRecordingIndex = RecordingIndex++;

            Stopwatch SubmitTimer;

            const int Priority = 0;

            if (ShouldSchedule())
            {
                if (m_VerboseLogging)
                {
                    int AttachmentCount = 0;
                    uint64_t AttachmentBytes = 0;
                    // NOTE(review): ReferencedChunks is populated but never read
                    eastl::hash_set<IoHash> ReferencedChunks;

                    ActionObject.IterateAttachments([&](CbFieldView Field) {
                        IoHash AttachData = Field.AsAttachment();

                        ReferencedChunks.insert(AttachData);
                        ++AttachmentCount;

                        if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData))
                        {
                            AttachmentBytes += ChunkData.GetSize();
                        }
                    });

                    zen::ExtendableStringBuilder<1024> ObjStr;
                    zen::CompactBinaryToJson(ActionObject, ObjStr);
                    ZEN_CONSOLE("work item {} ({} attachments, {} bytes): {}",
                                ActionId,
                                AttachmentCount,
                                NiceBytes(AttachmentBytes),
                                ObjStr);
                }

                if (m_DumpActions)
                {
                    int AttachmentCount = 0;
                    uint64_t AttachmentBytes = 0;

                    ActionObject.IterateAttachments([&](CbFieldView Field) {
                        IoHash AttachData = Field.AsAttachment();

                        ++AttachmentCount;

                        if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData))
                        {
                            AttachmentBytes += ChunkData.GetSize();
                        }
                    });

                    zen::ExtendableStringBuilder<1024> ObjStr;
                    zen::CompactBinaryToYaml(ActionObject, ObjStr);
                    ZEN_CONSOLE("action {} ({} attachments, {}):\n{}", ActionId, AttachmentCount, NiceBytes(AttachmentBytes), ObjStr);
                }

                if (zen::compute::ComputeServiceSession::EnqueueResult EnqueueResult =
                        ComputeSession.EnqueueActionToQueue(QueueId, ActionObject, Priority))
                {
                    const int32_t LsnField = EnqueueResult.Lsn;

                    --RemainingWorkItems;
                    ++SubmittedWorkItems;

                    if (!m_QuietLogging)
                    {
                        ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. {} remaining",
                                    SubmittedWorkItems,
                                    LsnField,
                                    NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
                                    RemainingWorkItems);
                    }

                    if (!m_OutputPath.empty())
                    {
                        // Persist the action itself (inputs) alongside its
                        // future result, and seed the summary entry
                        ActionSummaryEntry Entry;
                        Entry.Lsn = LsnField;
                        Entry.RecordingIndex = CurrentRecordingIndex;
                        Entry.ActionId = ActionId;
                        Entry.FunctionName = std::string(ActionObject["Function"sv].AsString());

                        if (!m_Binary)
                        {
                            // Write action object as YAML
                            ExtendableStringBuilder<4096> YamlStr;
                            CompactBinaryToYaml(ActionObject, YamlStr);

                            std::string_view Yaml = YamlStr;
                            zen::WriteFile(m_OutputPath / fmt::format("{}.action.yaml", LsnField),
                                           IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size()));

                            // Write decompressed input attachments
                            std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.action.attachments", LsnField);
                            bool AttDirCreated = false;

                            ActionObject.IterateAttachments([&](CbFieldView Field) {
                                IoHash AttachCid = Field.AsAttachment();
                                ++Entry.InputAttachments;

                                if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid))
                                {
                                    IoHash RawHash;
                                    uint64_t RawSize = 0;
                                    CompressedBuffer Compressed =
                                        CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize);
                                    SharedBuffer Decompressed = Compressed.Decompress();

                                    Entry.InputBytes += Decompressed.GetSize();

                                    // Create the directory lazily so empty
                                    // attachment sets leave no directory behind
                                    if (!AttDirCreated)
                                    {
                                        zen::CreateDirectories(AttDir);
                                        AttDirCreated = true;
                                    }

                                    zen::WriteFile(AttDir / AttachCid.ToHexString(),
                                                   IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize()));
                                }
                            });

                            if (!m_QuietLogging)
                            {
                                ZEN_CONSOLE("saved action: {}/{}.action.yaml ({} attachments)",
                                            m_OutputPath.string(),
                                            LsnField,
                                            Entry.InputAttachments);
                            }
                        }
                        else
                        {
                            // Build a CbPackage from the action and write as .pkg
                            CbPackage ActionPackage;
                            ActionPackage.SetObject(ActionObject);

                            ActionObject.IterateAttachments([&](CbFieldView Field) {
                                IoHash AttachCid = Field.AsAttachment();
                                ++Entry.InputAttachments;

                                if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid))
                                {
                                    IoHash RawHash;
                                    uint64_t RawSize = 0;
                                    CompressedBuffer Compressed =
                                        CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize);

                                    // Binary mode counts compressed bytes (YAML
                                    // mode above counts decompressed)
                                    Entry.InputBytes += ChunkData.GetSize();
                                    ActionPackage.AddAttachment(CbAttachment(std::move(Compressed), RawHash));
                                }
                            });

                            CompositeBuffer Serialized = FormatPackageMessageBuffer(ActionPackage);
                            zen::WriteFile(m_OutputPath / fmt::format("{}.action.pkg", LsnField), std::move(Serialized));

                            if (!m_QuietLogging)
                            {
                                ZEN_CONSOLE("saved action: {}/{}.action.pkg", m_OutputPath.string(), LsnField);
                            }
                        }

                        std::lock_guard Lock(SummaryLock);
                        SummaryEntries.emplace(LsnField, std::move(Entry));
                    }

                    PendingJobs.Insert(LsnField);
                }
                else
                {
                    if (!m_QuietLogging)
                    {
                        std::string_view FunctionName = ActionObject["Function"sv].AsString();
                        const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid();
                        const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid();

                        // NOTE(review): "<null>" placeholder -- the work
                        // descriptor path is not currently wired through
                        ZEN_ERROR(
                            "failed to resolve function for work with (Function:{},FunctionVersion:{},BuildSystemVersion:{}). Work "
                            "descriptor "
                            "at: 'file://{}'",
                            std::string(FunctionName),
                            FunctionVersion,
                            BuildSystemVersion,
                            "<null>");

                        EmitFunctionListOnce(m_FunctionList);
                    }

                    ++FailedWorkCounter;
                }
            }

            // Check for completed work

            DrainCompletedJobs();
            SendOrchestratorHeartbeat();
        },
        TargetParallelism);

    // Wait until all pending work is complete

    while (!PendingJobs.IsEmpty())
    {
        // TODO: improve this logic
        zen::Sleep(500);

        DrainCompletedJobs();
        SendOrchestratorHeartbeat();
    }

    // Merge timing data from queue history into summary entries

    if (!SummaryEntries.empty())
    {
        // RunnerAction::State indices (can't include functionrunner.h from here)
        constexpr int kStateNew = 0;
        constexpr int kStatePending = 1;
        constexpr int kStateRunning = 3;
        constexpr int kStateCompleted = 4; // first terminal state
        constexpr int kStateCount = 8;

        for (const auto& HistEntry : ComputeSession.GetQueueHistory(QueueId, 0))
        {
            std::lock_guard Lock(SummaryLock);
            if (auto It = SummaryEntries.find(HistEntry.Lsn); It != SummaryEntries.end())
            {
                // Find terminal state timestamp (Completed, Failed, Abandoned, or Cancelled)
                uint64_t EndTick = 0;
                for (int S = kStateCompleted; S < kStateCount; ++S)
                {
                    if (HistEntry.Timestamps[S] != 0)
                    {
                        EndTick = HistEntry.Timestamps[S];
                        break;
                    }
                }
                uint64_t StartTick = HistEntry.Timestamps[kStateNew];
                if (EndTick > StartTick)
                {
                    It->second.WallSeconds = float(double(EndTick - StartTick) / double(TimeSpan::TicksPerSecond));
                }
                It->second.CpuSeconds = HistEntry.CpuSeconds;
                It->second.SubmittedTicks = HistEntry.Timestamps[kStatePending];
                It->second.StartedTicks = HistEntry.Timestamps[kStateRunning];
                It->second.ExecutionLocation = HistEntry.ExecutionLocation;
            }
        }
    }

    // Write summary file if output path is set

    if (!m_OutputPath.empty() && !SummaryEntries.empty())
    {
        // Present entries in recording order regardless of completion order
        std::vector<ActionSummaryEntry> Sorted;
        Sorted.reserve(SummaryEntries.size());
        for (auto& [_, Entry] : SummaryEntries)
        {
            Sorted.push_back(std::move(Entry));
        }

        std::sort(Sorted.begin(), Sorted.end(), [](const ActionSummaryEntry& A, const ActionSummaryEntry& B) {
            return A.RecordingIndex < B.RecordingIndex;
        });

        auto FormatTimestamp = [](uint64_t Ticks) -> std::string {
            if (Ticks == 0)
            {
                return "-";
            }
            return DateTime(Ticks).ToString("%H:%M:%S.%s");
        };

        ExtendableStringBuilder<4096> Summary;
        Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8} {:>8} {:>12} {:>12} {:<24}\n",
                                   "LSN",
                                   "Index",
                                   "ActionId",
                                   "Function",
                                   "InAtt",
                                   "InBytes",
                                   "OutAtt",
                                   "OutBytes",
                                   "Wall(s)",
                                   "CPU(s)",
                                   "Submitted",
                                   "Started",
                                   "Location"));
        Summary.Append(fmt::format("{:-<8} {:-<8} {:-<40} {:-<40} {:-<8} {:-<12} {:-<8} {:-<12} {:-<8} {:-<8} {:-<12} {:-<12} {:-<24}\n",
                                   "",
                                   "",
                                   "",
                                   "",
                                   "",
                                   "",
                                   "",
                                   "",
                                   "",
                                   "",
                                   "",
                                   "",
                                   ""));

        for (const ActionSummaryEntry& Entry : Sorted)
        {
            Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8.2f} {:>8.2f} {:>12} {:>12} {:<24}\n",
                                       Entry.Lsn,
                                       Entry.RecordingIndex,
                                       Entry.ActionId,
                                       Entry.FunctionName,
                                       Entry.InputAttachments,
                                       NiceBytes(Entry.InputBytes),
                                       Entry.OutputAttachments,
                                       NiceBytes(Entry.OutputBytes),
                                       Entry.WallSeconds,
                                       Entry.CpuSeconds,
                                       FormatTimestamp(Entry.SubmittedTicks),
                                       FormatTimestamp(Entry.StartedTicks),
                                       Entry.ExecutionLocation));
        }

        std::filesystem::path SummaryPath = m_OutputPath / "summary.txt";
        std::string_view SummaryStr = Summary;
        zen::WriteFile(SummaryPath, IoBuffer(IoBuffer::Clone, SummaryStr.data(), SummaryStr.size()));

        ZEN_CONSOLE("wrote summary to {}", SummaryPath.string());

        // HTML report: sortable/filterable virtualized table with CSV export.
        // Only generated in YAML mode, where the per-LSN .action.yaml /
        // .result.yaml files it links to actually exist.
        if (!m_Binary)
        {
            auto EscapeHtml = [](std::string_view Input) -> std::string {
                std::string Out;
                Out.reserve(Input.size());
                for (char C : Input)
                {
                    switch (C)
                    {
                        case '&':
                            Out += "&amp;";
                            break;
                        case '<':
                            Out += "&lt;";
                            break;
                        case '>':
                            Out += "&gt;";
                            break;
                        case '"':
                            Out += "&quot;";
                            break;
                        case '\'':
                            Out += "&#39;";
                            break;
                        default:
                            Out += C;
                    }
                }
                return Out;
            };

            auto EscapeJson = [](std::string_view Input) -> std::string {
                std::string Out;
                Out.reserve(Input.size());
                for (char C : Input)
                {
                    switch (C)
                    {
                        case '"':
                            Out += "\\\"";
                            break;
                        case '\\':
                            Out += "\\\\";
                            break;
                        case '\n':
                            Out += "\\n";
                            break;
                        case '\r':
                            Out += "\\r";
                            break;
                        case '\t':
                            Out += "\\t";
                            break;
                        default:
                            if (static_cast<unsigned char>(C) < 0x20)
                            {
                                Out += fmt::format("\\u{:04x}", static_cast<unsigned>(static_cast<unsigned char>(C)));
                            }
                            else
                            {
                                Out += C;
                            }
                    }
                }
                return Out;
            };

            ExtendableStringBuilder<8192> Html;

            Html.Append(std::string_view(R"(<!DOCTYPE html>
<html><head><meta charset="utf-8"><title>Exec Summary</title>
<style>
body{font-family:system-ui,sans-serif;margin:20px;background:#fafafa}
#container{overflow-y:auto;height:calc(100vh - 120px)}
table{border-collapse:collapse;width:100%}
th,td{border:1px solid #ddd;padding:6px 10px;text-align:left;white-space:nowrap}
th{background:#f0f0f0;cursor:pointer;user-select:none;position:sticky;top:0;z-index:1}
th:hover{background:#e0e0e0}
th .arrow{font-size:0.7em;margin-left:4px}
tr:hover{background:#e8f0fe}
input{padding:6px 10px;margin-bottom:12px;width:300px;border:1px solid #ccc;border-radius:4px}
button{padding:6px 14px;margin-left:8px;margin-bottom:12px;border:1px solid #ccc;border-radius:4px;background:#f0f0f0;cursor:pointer}
button:hover{background:#e0e0e0}
a{color:#1a73e8;text-decoration:none}
a:hover{text-decoration:underline}
.num{text-align:right}
</style></head><body>
<h2>Exec Summary</h2>
<input type="text" id="filter" placeholder="Filter by function name..."><button id="csvBtn">Export CSV</button>
<div id="container">
<table><thead><tr>
<th data-col="0">LSN <span class="arrow"></span></th>
<th data-col="1">Index <span class="arrow"></span></th>
<th data-col="2">Action ID <span class="arrow"></span></th>
<th data-col="3">Function <span class="arrow"></span></th>
<th data-col="4">In Attachments <span class="arrow"></span></th>
<th data-col="5">In Bytes <span class="arrow"></span></th>
<th data-col="6">Out Attachments <span class="arrow"></span></th>
<th data-col="7">Out Bytes <span class="arrow"></span></th>
<th data-col="8">Wall(s) <span class="arrow"></span></th>
<th data-col="9">CPU(s) <span class="arrow"></span></th>
<th data-col="10">Submitted <span class="arrow"></span></th>
<th data-col="11">Started <span class="arrow"></span></th>
<th data-col="12">Location <span class="arrow"></span></th>
</tr></thead><tbody>
<tr id="spacerTop"><td colspan="13"></td></tr>
<tr id="spacerBot"><td colspan="13"></td></tr>
</tbody></table></div>
<script>
const DATA=[
)"));

            std::string_view ResultExt = ".result.yaml";
            std::string_view ActionExt = ".action.yaml";

            for (const ActionSummaryEntry& Entry : Sorted)
            {
                std::string SafeName = EscapeJson(EscapeHtml(Entry.FunctionName));
                std::string ActionIdStr = Entry.ActionId.ToHexString();
                std::string ActionLink;
                // NOTE(review): ActionExt is a non-empty literal, so this
                // check is currently always true
                if (!ActionExt.empty())
                {
                    ActionLink = EscapeJson(fmt::format(" <a href=\"{}{}\">[action]</a>", Entry.Lsn, ActionExt));
                }

                // Indices: 0=lsn, 1=idx, 2=actionId, 3=fn, 4=inAtt, 5=inBytes, 6=outAtt, 7=outBytes,
                //          8=wall, 9=cpu, 10=niceBytesIn, 11=niceBytesOut, 12=actionLink,
                //          13=submittedTicks, 14=startedTicks, 15=submittedDisplay, 16=startedDisplay,
                //          17=location
                Html.Append(
                    fmt::format("[{},{},\"{}\",\"{}\",{},{},{},{},{:.6f},{:.6f},\"{}\",\"{}\",\"{}\",{},{},\"{}\",\"{}\",\"{}\"],\n",
                                Entry.Lsn,
                                Entry.RecordingIndex,
                                ActionIdStr,
                                SafeName,
                                Entry.InputAttachments,
                                Entry.InputBytes,
                                Entry.OutputAttachments,
                                Entry.OutputBytes,
                                Entry.WallSeconds,
                                Entry.CpuSeconds,
                                EscapeJson(NiceBytes(Entry.InputBytes)),
                                EscapeJson(NiceBytes(Entry.OutputBytes)),
                                ActionLink,
                                Entry.SubmittedTicks,
                                Entry.StartedTicks,
                                FormatTimestamp(Entry.SubmittedTicks),
                                FormatTimestamp(Entry.StartedTicks),
                                EscapeJson(EscapeHtml(Entry.ExecutionLocation))));
            }

            Html.Append(fmt::format(R"(];
const RESULT_EXT="{}";
)",
                                    ResultExt));

            Html.Append(std::string_view(R"JS((function(){
const ROW_H=33,BUF=20;
const container=document.getElementById("container");
const tbody=container.querySelector("tbody");
const headers=container.querySelectorAll("th");
const filterInput=document.getElementById("filter");
const spacerTop=document.getElementById("spacerTop");
const spacerBot=document.getElementById("spacerBot");
let view=[...DATA.keys()];
let sortCol=-1,sortAsc=true;
const COLS=[
  {f:0,t:"n"},{f:1,t:"n"},{f:2,t:"s"},{f:3,t:"s"},
  {f:4,t:"n"},{f:5,t:"n"},{f:6,t:"n"},{f:7,t:"n"},
  {f:8,t:"n"},{f:9,t:"n"},{f:13,t:"n"},{f:14,t:"n"},{f:17,t:"s"}
];
function rowHtml(i){
  const d=DATA[view[i]];
  const bg=i%2?' style="background:#f9f9f9"':'';
  return '<tr'+bg+'>'+
    '<td class="num"><a href="'+d[0]+RESULT_EXT+'">'+d[0]+'</a></td>'+
    '<td class="num">'+d[1]+'</td>'+
    '<td><code>'+d[2]+'</code></td>'+
    '<td>'+d[3]+d[12]+'</td>'+
    '<td class="num">'+d[4]+'</td>'+
    '<td class="num">'+d[10]+'</td>'+
    '<td class="num">'+d[6]+'</td>'+
    '<td class="num">'+d[11]+'</td>'+
    '<td class="num">'+d[8].toFixed(2)+'</td>'+
    '<td class="num">'+d[9].toFixed(2)+'</td>'+
    '<td class="num">'+d[15]+'</td>'+
    '<td class="num">'+d[16]+'</td>'+
    '<td>'+d[17]+'</td></tr>';
}
let lastFirst=-1,lastLast=-1;
function render(){
  const scrollTop=container.scrollTop;
  const viewH=container.clientHeight;
  let first=Math.floor(scrollTop/ROW_H)-BUF;
  let last=Math.ceil((scrollTop+viewH)/ROW_H)+BUF;
  if(first<0) first=0;
  if(last>=view.length) last=view.length-1;
  if(first===lastFirst&&last===lastLast) return;
  lastFirst=first;lastLast=last;
  const rows=[];
  for(let i=first;i<=last;i++) rows.push(rowHtml(i));
  spacerTop.style.height=(first*ROW_H)+'px';
  spacerBot.style.height=((view.length-1-last)*ROW_H)+'px';
  const mid=rows.join('');
  const topTr='<tr id="spacerTop"><td colspan="13" style="border:0;padding:0;height:'+spacerTop.style.height+'"></td></tr>';
  const botTr='<tr id="spacerBot"><td colspan="13" style="border:0;padding:0;height:'+spacerBot.style.height+'"></td></tr>';
  tbody.innerHTML=topTr+mid+botTr;
}
function applySort(){
  if(sortCol<0) return;
  const c=COLS[sortCol];
  view.sort((a,b)=>{
    const va=DATA[a][c.f],vb=DATA[b][c.f];
    if(c.t==="n") return sortAsc?va-vb:vb-va;
    return sortAsc?(va<vb?-1:va>vb?1:0):(va>vb?-1:va<vb?1:0);
  });
}
function rebuild(){
  const q=filterInput.value.toLowerCase();
  view=[];
  for(let i=0;i<DATA.length;i++){
    if(!q||DATA[i][3].toLowerCase().includes(q)) view.push(i);
  }
  applySort();
  lastFirst=lastLast=-1;
  render();
}
headers.forEach(th=>{
  th.addEventListener("click",()=>{
    const col=parseInt(th.dataset.col);
    if(sortCol===col){sortAsc=!sortAsc}else{sortCol=col;sortAsc=true}
    headers.forEach(h=>h.querySelector(".arrow").textContent="");
    th.querySelector(".arrow").textContent=sortAsc?"\u25B2":"\u25BC";
    applySort();
    lastFirst=lastLast=-1;
    render();
  });
});
filterInput.addEventListener("input",()=>rebuild());
let ticking=false;
container.addEventListener("scroll",()=>{
  if(!ticking){ticking=true;requestAnimationFrame(()=>{render();ticking=false})}
});
rebuild();
document.getElementById("csvBtn").addEventListener("click",()=>{
  const H=["LSN","Index","Action ID","Function","In Attachments","In Bytes","Out Attachments","Out Bytes","Wall(s)","CPU(s)","Submitted","Started","Location"];
  const esc=v=>{const s=String(v);return s.includes(',')||s.includes('"')||s.includes('\n')?'"'+s.replace(/"/g,'""')+'"':s};
  const rows=[H.join(",")];
  for(let i=0;i<view.length;i++){
    const d=DATA[view[i]];
    rows.push([d[0],d[1],d[2],d[3],d[4],d[5],d[6],d[7],d[8],d[9],d[15],d[16],d[17]].map(esc).join(","));
  }
  const blob=new Blob([rows.join("\n")],{type:"text/csv"});
  const a=document.createElement("a");
  a.href=URL.createObjectURL(blob);
  a.download="summary.csv";
  a.click();
  URL.revokeObjectURL(a.href);
});
})();
</script></body></html>
)JS"));

            std::filesystem::path HtmlPath = m_OutputPath / "summary.html";
            std::string_view HtmlStr = Html;
            zen::WriteFile(HtmlPath, IoBuffer(IoBuffer::Clone, HtmlStr.data(), HtmlStr.size()));

            ZEN_CONSOLE("wrote HTML summary to {}", HtmlPath.string());
        }
    }

    if (FailedWorkCounter)
    {
        return 1;
    }

    return 0;
}
+
int
ExecCommand::LocalMessagingExecute()
{
    // Non-HTTP work submission path (the kDirect operating mode)
    //
    // To be reimplemented using final transport
    //
    // NOTE(review): this stub reports success (0), so running with the direct
    // mode prints "all work items completed successfully" without submitting
    // any work -- confirm whether a non-zero status would be more appropriate.

    return 0;
}
+
+//////////////////////////////////////////////////////////////////////////
+
+int
+ExecCommand::HttpExecute()
+{
+ ZEN_ASSERT(m_ChunkResolver);
+ ChunkResolver& Resolver = *m_ChunkResolver;
+
+ std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");
+
+ zen::compute::ComputeServiceSession ComputeSession(Resolver);
+ ComputeSession.AddRemoteRunner(Resolver, TempPath, m_HostName);
+
+ return ExecUsingSession(ComputeSession);
+}
+
+int
+ExecCommand::BeaconExecute()
+{
+ ZEN_ASSERT(m_ChunkResolver);
+ ChunkResolver& Resolver = *m_ChunkResolver;
+ std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");
+
+ zen::compute::ComputeServiceSession ComputeSession(Resolver);
+
+ if (!m_OrchestratorUrl.empty())
+ {
+ ZEN_CONSOLE_INFO("using orchestrator at {}", m_OrchestratorUrl);
+ ComputeSession.SetOrchestratorEndpoint(m_OrchestratorUrl);
+ ComputeSession.SetOrchestratorBasePath(TempPath);
+ }
+ else
+ {
+ ZEN_CONSOLE_INFO("note: using hard-coded local worker path");
+ ComputeSession.AddRemoteRunner(Resolver, TempPath, "http://localhost:8558");
+ }
+
+ return ExecUsingSession(ComputeSession);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+void
+ExecCommand::RegisterWorkerFunctionsFromDescription(const CbObject& WorkerDesc, const IoHash& WorkerId)
+{
+ const Guid WorkerBuildSystemVersion = WorkerDesc["buildsystem_version"sv].AsUuid();
+
+ for (auto& Item : WorkerDesc["functions"sv])
+ {
+ CbObjectView Function = Item.AsObjectView();
+
+ std::string_view FunctionName = Function["name"sv].AsString();
+ const Guid FunctionVersion = Function["version"sv].AsUuid();
+
+ m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName},
+ .FunctionVersion = FunctionVersion,
+ .BuildSystemVersion = WorkerBuildSystemVersion,
+ .WorkerId = WorkerId});
+ }
+}
+
+void
+ExecCommand::EmitFunctionListOnce(const std::vector<FunctionDefinition>& FunctionList)
+{
+ if (m_FunctionListEmittedOnce == false)
+ {
+ EmitFunctionList(FunctionList);
+
+ m_FunctionListEmittedOnce = true;
+ }
+}
+
+int
+ExecCommand::DumpWorkItems()
+{
+ std::atomic<int> EmittedCount{0};
+
+ eastl::hash_map<IoHash, uint64_t> SeenAttachments; // Attachment CID -> count of references
+
+ m_RecordingReader->IterateActions(
+ [&](CbObject ActionObject, const IoHash& ActionId) {
+ eastl::hash_map<IoHash, CompressedBuffer> Attachments;
+
+ uint64_t AttachmentBytes = 0;
+ uint64_t UncompressedAttachmentBytes = 0;
+
+ ActionObject.IterateAttachments([&](const zen::CbFieldView AttachmentField) {
+ const IoHash AttachmentCid = AttachmentField.GetValue().AsHash();
+ IoBuffer AttachmentData = m_ChunkResolver->FindChunkByCid(AttachmentCid);
+ IoHash RawHash;
+ uint64_t RawSize = 0;
+ CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize);
+ Attachments[AttachmentCid] = CompressedData;
+
+ AttachmentBytes += CompressedData.GetCompressedSize();
+ UncompressedAttachmentBytes += CompressedData.DecodeRawSize();
+
+ if (auto [Iter, Inserted] = SeenAttachments.insert({AttachmentCid, 1}); !Inserted)
+ {
+ ++Iter->second;
+ }
+ });
+
+ zen::ExtendableStringBuilder<1024> ObjStr;
+
+# if 0
+ zen::CompactBinaryToJson(ActionObject, ObjStr);
+ ZEN_CONSOLE("work item {} ({} attachments): {}", ActionId, Attachments.size(), ObjStr);
+# else
+ zen::CompactBinaryToYaml(ActionObject, ObjStr);
+ ZEN_CONSOLE("work item {} ({} attachments, {}->{} bytes):\n{}",
+ ActionId,
+ Attachments.size(),
+ AttachmentBytes,
+ UncompressedAttachmentBytes,
+ ObjStr);
+# endif
+
+ ++EmittedCount;
+ },
+ 1);
+
+ ZEN_CONSOLE("emitted: {} actions", EmittedCount.load());
+
+ eastl::map<uint64_t, std::vector<IoHash>> ReferenceHistogram;
+
+ for (const auto& [K, V] : SeenAttachments)
+ {
+ if (V > 1)
+ {
+ ReferenceHistogram[V].push_back(K);
+ }
+ }
+
+ for (const auto& [RefCount, Cids] : ReferenceHistogram)
+ {
+ ZEN_CONSOLE("{} attachments with {} references", Cids.size(), RefCount);
+ }
+
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
int
ExecCommand::BuildActionsLog()
{
    // Replay the recording while capturing a new recording log at
    // m_RecordingLogPath. Currently gated behind ZEN_NOT_IMPLEMENTED.
    ZEN_ASSERT(m_ChunkResolver);
    ChunkResolver& Resolver = *m_ChunkResolver;

    if (m_RecordingPath.empty())
    {
        throw OptionParseException("need to specify recording path", m_Options.help());
    }

    // Refuse to clobber an existing log directory
    if (std::filesystem::exists(m_RecordingLogPath))
    {
        throw OptionParseException(fmt::format("recording log directory '{}' already exists!", m_RecordingLogPath), m_Options.help());
    }

    // NOTE(review): assuming ZEN_NOT_IMPLEMENTED aborts or throws, everything
    // below is currently unreachable scaffolding for the finished feature --
    // confirm the macro's semantics.
    ZEN_NOT_IMPLEMENTED("build log generation not implemented yet!");

    std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");

    zen::compute::ComputeServiceSession ComputeSession(Resolver);
    ComputeSession.StartRecording(Resolver, m_RecordingLogPath);

    return ExecUsingSession(ComputeSession);
}
+
+void
+ExecCommand::EmitFunctionList(const std::vector<FunctionDefinition>& FunctionList)
+{
+ ZEN_CONSOLE("=== Known functions:\n===========================");
+
+ ZEN_CONSOLE("{:30} {:36} {:36} {}", "function", "version", "build system", "worker id");
+
+ for (const FunctionDefinition& Func : FunctionList)
+ {
+ ZEN_CONSOLE("{:30} {:36} {:36} {}", Func.FunctionName, Func.FunctionVersion, Func.BuildSystemVersion, Func.WorkerId);
+ }
+
+ ZEN_CONSOLE("===========================");
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h
new file mode 100644
index 000000000..6311354c0
--- /dev/null
+++ b/src/zen/cmds/exec_cmd.h
@@ -0,0 +1,101 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+#include <zencompute/recordingreader.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/guid.h>
+#include <zencore/iohash.h>
+
+#include <filesystem>
+#include <functional>
+#include <unordered_map>
+
+namespace zen {
+class CbPackage;
+class CbObject;
+struct IoHash;
+class ChunkResolver;
+} // namespace zen
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+namespace zen::compute {
+class ComputeServiceSession;
+}
+
+namespace zen {
+
+/**
+ * Zen CLI command for executing functions from a recording
+ *
+ * Mostly for testing and debugging purposes
+ */
+
class ExecCommand : public ZenCmdBase
{
public:
    ExecCommand();
    ~ExecCommand();

    static constexpr char Name[] = "exec";
    static constexpr char Description[] = "Execute functions from a recording";

    virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
    virtual cxxopts::Options& Options() override { return m_Options; }

private:
    cxxopts::Options m_Options{Name, Description};
    std::string m_HostName;                   // Remote function service endpoint (http mode)
    std::string m_OrchestratorUrl;            // Optional orchestrator endpoint (beacon mode / client registration)
    std::filesystem::path m_BeaconPath;
    std::filesystem::path m_RecordingPath;    // Recording to replay
    std::filesystem::path m_RecordingLogPath; // Where BuildActionsLog() writes its output
    int m_Offset = 0;                         // Skip this many actions before scheduling any
    int m_Stride = 1;                         // Schedule every Nth action
    int m_Limit = 0;                          // Max actions to submit (0 = unlimited)
    // NOTE(review): m_Quiet appears to duplicate m_QuietLogging below --
    // confirm whether both are still wired up
    bool m_Quiet = false;
    std::string m_Mode{"http"};               // Operating-mode selector
    std::filesystem::path m_OutputPath;       // If set, actions/results/summary are saved here
    bool m_Binary = false;                    // Save .pkg packages instead of YAML

    // One function advertised by a worker description
    struct FunctionDefinition
    {
        std::string FunctionName;
        zen::Guid FunctionVersion;
        zen::Guid BuildSystemVersion;
        zen::IoHash WorkerId;
    };

    bool m_FunctionListEmittedOnce = false;
    void EmitFunctionListOnce(const std::vector<FunctionDefinition>& FunctionList);
    void EmitFunctionList(const std::vector<FunctionDefinition>& FunctionList);

    std::unordered_map<zen::IoHash, zen::CbPackage> m_WorkerMap; // Worker id -> worker description package
    std::vector<FunctionDefinition> m_FunctionList;
    bool m_VerboseLogging = false;
    bool m_QuietLogging = false;
    bool m_DumpActions = false;

    // Non-owning; presumably initialized during Run() before any execution
    // mode is entered -- the modes assert m_ChunkResolver is non-null
    zen::ChunkResolver* m_ChunkResolver = nullptr;
    zen::compute::RecordingReaderBase* m_RecordingReader = nullptr;

    void RegisterWorkerFunctionsFromDescription(const zen::CbObject& WorkerDesc, const zen::IoHash& WorkerId);

    // Shared driver used by every mode that actually submits work
    int ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession);

    // Execution modes

    int DumpWorkItems();
    int HttpExecute();
    int InProcessExecute();
    int LocalMessagingExecute();
    int BeaconExecute();
    int BuildActionsLog();
};
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zen/cmds/info_cmd.h b/src/zen/cmds/info_cmd.h
index 231565bfd..dc108b8a2 100644
--- a/src/zen/cmds/info_cmd.h
+++ b/src/zen/cmds/info_cmd.h
@@ -9,6 +9,9 @@ namespace zen {
class InfoCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "info";
+ static constexpr char Description[] = "Show high level Zen server information";
+
InfoCommand();
~InfoCommand();
@@ -17,7 +20,7 @@ public:
// virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; }
private:
- cxxopts::Options m_Options{"info", "Show high level zen store information"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
};
diff --git a/src/zen/cmds/print_cmd.cpp b/src/zen/cmds/print_cmd.cpp
index 030cc8b66..c6b250fdf 100644
--- a/src/zen/cmds/print_cmd.cpp
+++ b/src/zen/cmds/print_cmd.cpp
@@ -84,7 +84,7 @@ PrintCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
}
else
{
- MakeSafeAbsolutePathÍnPlace(m_Filename);
+ MakeSafeAbsolutePathInPlace(m_Filename);
Fc = ReadFile(m_Filename);
}
@@ -244,7 +244,7 @@ PrintPackageCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** ar
if (m_Filename.empty())
throw OptionParseException("'--source' is required", m_Options.help());
- MakeSafeAbsolutePathÍnPlace(m_Filename);
+ MakeSafeAbsolutePathInPlace(m_Filename);
FileContents Fc = ReadFile(m_Filename);
IoBuffer Data = Fc.Flatten();
CbPackage Package;
diff --git a/src/zen/cmds/print_cmd.h b/src/zen/cmds/print_cmd.h
index 6c1529b7c..f4a97e218 100644
--- a/src/zen/cmds/print_cmd.h
+++ b/src/zen/cmds/print_cmd.h
@@ -11,6 +11,9 @@ namespace zen {
class PrintCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "print";
+ static constexpr char Description[] = "Print compact binary object";
+
PrintCommand();
~PrintCommand();
@@ -19,7 +22,7 @@ public:
virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; }
private:
- cxxopts::Options m_Options{"print", "Print compact binary object"};
+ cxxopts::Options m_Options{Name, Description};
std::filesystem::path m_Filename;
bool m_ShowCbObjectTypeInfo = false;
};
@@ -29,6 +32,9 @@ private:
class PrintPackageCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "printpackage";
+ static constexpr char Description[] = "Print compact binary package";
+
PrintPackageCommand();
~PrintPackageCommand();
@@ -37,7 +43,7 @@ public:
virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; }
private:
- cxxopts::Options m_Options{"printpkg", "Print compact binary package"};
+ cxxopts::Options m_Options{Name, Description};
std::filesystem::path m_Filename;
bool m_ShowCbObjectTypeInfo = false;
};
diff --git a/src/zen/cmds/projectstore_cmd.cpp b/src/zen/cmds/projectstore_cmd.cpp
index 519b68126..db931e49a 100644
--- a/src/zen/cmds/projectstore_cmd.cpp
+++ b/src/zen/cmds/projectstore_cmd.cpp
@@ -41,12 +41,10 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
-namespace {
+namespace projectstore_impl {
using namespace std::literals;
-#define ZEN_CLOUD_STORAGE "Cloud Storage"
-
void WriteAuthOptions(CbObjectWriter& Writer,
std::string_view JupiterOpenIdProvider,
std::string_view JupiterAccessToken,
@@ -500,7 +498,7 @@ namespace {
return {};
}
-} // namespace
+} // namespace projectstore_impl
///////////////////////////////////////
@@ -522,6 +520,7 @@ DropProjectCommand::~DropProjectCommand()
void
DropProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace projectstore_impl;
ZEN_UNUSED(GlobalOptions);
if (!ParseOptions(argc, argv))
@@ -611,6 +610,7 @@ ProjectInfoCommand::~ProjectInfoCommand()
void
ProjectInfoCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace projectstore_impl;
ZEN_UNUSED(GlobalOptions);
if (!ParseOptions(argc, argv))
@@ -697,6 +697,7 @@ CreateProjectCommand::~CreateProjectCommand() = default;
void
CreateProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace projectstore_impl;
ZEN_UNUSED(GlobalOptions);
using namespace std::literals;
@@ -766,6 +767,7 @@ CreateOplogCommand::~CreateOplogCommand() = default;
void
CreateOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace projectstore_impl;
ZEN_UNUSED(GlobalOptions);
using namespace std::literals;
@@ -990,6 +992,7 @@ ExportOplogCommand::~ExportOplogCommand()
void
ExportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace projectstore_impl;
using namespace std::literals;
ZEN_UNUSED(GlobalOptions);
@@ -1470,6 +1473,20 @@ ImportOplogCommand::ImportOplogCommand()
"Enables both 'boost-worker-count' and 'boost-worker-memory' - may cause computer to be less responsive",
cxxopts::value(m_BoostWorkers),
"<boostworkermemory>");
+ m_Options.add_option(
+ "",
+ "",
+ "allow-partial-block-requests",
+ "Allow request for partial chunk blocks.\n"
+ " false = only full block requests allowed\n"
+ " mixed = multiple partial block ranges requests per block allowed to zen cache, single partial block range "
+ "request per block to host\n"
+ " zencacheonly = multiple partial block ranges requests per block allowed to zen cache, only full block requests "
+ "allowed to host\n"
+ " true = multiple partial block ranges requests per block allowed to zen cache and host\n"
+ "Defaults to 'mixed'.",
+ cxxopts::value(m_AllowPartialBlockRequests),
+ "<allowpartialblockrequests>");
m_Options.parse_positional({"project", "oplog", "gcpath"});
m_Options.positional_help("[<projectid> <oplogid> [<gcpath>]]");
@@ -1482,6 +1499,7 @@ ImportOplogCommand::~ImportOplogCommand()
void
ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace projectstore_impl;
using namespace std::literals;
ZEN_UNUSED(GlobalOptions);
@@ -1514,6 +1532,13 @@ ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg
throw OptionParseException("'--oplog' is required", m_Options.help());
}
+ EPartialBlockRequestMode Mode = PartialBlockRequestModeFromString(m_AllowPartialBlockRequests);
+ if (Mode == EPartialBlockRequestMode::Invalid)
+ {
+ throw OptionParseException(fmt::format("'--allow-partial-block-requests' ('{}') is invalid", m_AllowPartialBlockRequests),
+ m_Options.help());
+ }
+
HttpClient Http(m_HostName);
m_ProjectName = ResolveProject(Http, m_ProjectName);
if (m_ProjectName.empty())
@@ -1651,6 +1676,9 @@ ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg
{
Writer.AddBool("boostworkermemory"sv, true);
}
+
+ Writer.AddString("partialblockrequestmode", m_AllowPartialBlockRequests);
+
if (!m_FileDirectoryPath.empty())
{
Writer.BeginObject("file"sv);
@@ -1766,6 +1794,7 @@ SnapshotOplogCommand::~SnapshotOplogCommand()
void
SnapshotOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace projectstore_impl;
using namespace std::literals;
ZEN_UNUSED(GlobalOptions);
@@ -1830,6 +1859,7 @@ ProjectStatsCommand::~ProjectStatsCommand()
void
ProjectStatsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace projectstore_impl;
ZEN_UNUSED(GlobalOptions);
if (!ParseOptions(argc, argv))
@@ -1882,6 +1912,7 @@ ProjectOpDetailsCommand::~ProjectOpDetailsCommand()
void
ProjectOpDetailsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace projectstore_impl;
ZEN_UNUSED(GlobalOptions);
if (!ParseOptions(argc, argv))
@@ -1997,6 +2028,7 @@ OplogMirrorCommand::~OplogMirrorCommand()
void
OplogMirrorCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace projectstore_impl;
ZEN_UNUSED(GlobalOptions);
if (!ParseOptions(argc, argv))
@@ -2264,6 +2296,7 @@ OplogValidateCommand::~OplogValidateCommand()
void
OplogValidateCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace projectstore_impl;
ZEN_UNUSED(GlobalOptions);
if (!ParseOptions(argc, argv))
@@ -2415,6 +2448,7 @@ OplogDownloadCommand::~OplogDownloadCommand()
void
OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace projectstore_impl;
ZEN_UNUSED(GlobalOptions);
if (!ParseOptions(argc, argv))
@@ -2432,7 +2466,7 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a
{
m_SystemRootDir = PickDefaultSystemRootDirectory();
}
- MakeSafeAbsolutePathÍnPlace(m_SystemRootDir);
+ MakeSafeAbsolutePathInPlace(m_SystemRootDir);
};
ParseSystemOptions();
@@ -2570,36 +2604,37 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a
StorageInstance Storage;
- ClientSettings.AssumeHttp2 = ResolveRes.HostAssumeHttp2;
+ ClientSettings.AssumeHttp2 = ResolveRes.Cloud.AssumeHttp2;
ClientSettings.MaximumInMemoryDownloadSize = m_BoostWorkerMemory ? RemoteStoreOptions::DefaultMaxBlockSize : 1024u * 1024u;
- Storage.BuildStorageHttp = std::make_unique<HttpClient>(ResolveRes.HostUrl, ClientSettings);
+ Storage.BuildStorageHttp = std::make_unique<HttpClient>(ResolveRes.Cloud.Address, ClientSettings);
+ Storage.BuildStorageHost = ResolveRes.Cloud;
BuildStorageCache::Statistics StorageCacheStats;
std::atomic<bool> AbortFlag(false);
- if (!ResolveRes.CacheUrl.empty())
+ if (!ResolveRes.Cache.Address.empty())
{
Storage.CacheHttp = std::make_unique<HttpClient>(
- ResolveRes.CacheUrl,
+ ResolveRes.Cache.Address,
HttpClientSettings{
.LogCategory = "httpcacheclient",
.ConnectTimeout = std::chrono::milliseconds{3000},
.Timeout = std::chrono::milliseconds{30000},
- .AssumeHttp2 = ResolveRes.CacheAssumeHttp2,
+ .AssumeHttp2 = ResolveRes.Cache.AssumeHttp2,
.AllowResume = true,
.RetryCount = 0,
.MaximumInMemoryDownloadSize = m_BoostWorkerMemory ? RemoteStoreOptions::DefaultMaxBlockSize : 1024u * 1024u},
[&AbortFlag]() { return AbortFlag.load(); });
- Storage.CacheName = ResolveRes.CacheName;
+ Storage.CacheHost = ResolveRes.Cache;
}
if (!m_Quiet)
{
std::string StorageDescription =
fmt::format("Cloud {}{}. SessionId {}. Namespace '{}', Bucket '{}'",
- ResolveRes.HostName,
- (ResolveRes.HostUrl == ResolveRes.HostName) ? "" : fmt::format(" {}", ResolveRes.HostUrl),
+ ResolveRes.Cloud.Name,
+ (ResolveRes.Cloud.Address == ResolveRes.Cloud.Name) ? "" : fmt::format(" {}", ResolveRes.Cloud.Address),
Storage.BuildStorageHttp->GetSessionId(),
m_Namespace,
m_Bucket);
@@ -2610,8 +2645,8 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a
{
std::string CacheDescription =
fmt::format("Zen {}{}. SessionId {}. Namespace '{}', Bucket '{}'",
- ResolveRes.CacheName,
- (ResolveRes.CacheUrl == ResolveRes.CacheName) ? "" : fmt::format(" {}", ResolveRes.CacheUrl),
+ ResolveRes.Cache.Name,
+ (ResolveRes.Cache.Address == ResolveRes.Cache.Name) ? "" : fmt::format(" {}", ResolveRes.Cache.Address),
Storage.CacheHttp->GetSessionId(),
m_Namespace,
m_Bucket);
@@ -2627,11 +2662,10 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a
Storage.BuildStorage =
CreateJupiterBuildStorage(Log(), *Storage.BuildStorageHttp, StorageStats, m_Namespace, m_Bucket, m_AllowRedirect, StorageTempPath);
- Storage.StorageName = ResolveRes.HostName;
if (Storage.CacheHttp)
{
- Storage.BuildCacheStorage = CreateZenBuildStorageCache(
+ Storage.CacheStorage = CreateZenBuildStorageCache(
*Storage.CacheHttp,
StorageCacheStats,
m_Namespace,
diff --git a/src/zen/cmds/projectstore_cmd.h b/src/zen/cmds/projectstore_cmd.h
index 56ef858f5..1ba98b39e 100644
--- a/src/zen/cmds/projectstore_cmd.h
+++ b/src/zen/cmds/projectstore_cmd.h
@@ -16,6 +16,9 @@ class ProjectStoreCommand : public ZenCmdBase
class DropProjectCommand : public ProjectStoreCommand
{
public:
+ static constexpr char Name[] = "project-drop";
+ static constexpr char Description[] = "Drop project or project oplog";
+
DropProjectCommand();
~DropProjectCommand();
@@ -23,7 +26,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"project-drop", "Drop project or project oplog"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_ProjectName;
std::string m_OplogName;
@@ -33,13 +36,16 @@ private:
class ProjectInfoCommand : public ProjectStoreCommand
{
public:
+ static constexpr char Name[] = "project-info";
+ static constexpr char Description[] = "Info on project or project oplog";
+
ProjectInfoCommand();
~ProjectInfoCommand();
virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"project-info", "Info on project or project oplog"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_ProjectName;
std::string m_OplogName;
@@ -48,6 +54,9 @@ private:
class CreateProjectCommand : public ProjectStoreCommand
{
public:
+ static constexpr char Name[] = "project-create";
+ static constexpr char Description[] = "Create a project";
+
CreateProjectCommand();
~CreateProjectCommand();
@@ -55,7 +64,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"project-create", "Create project, the project must not already exist."};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_ProjectId;
std::string m_RootDir;
@@ -68,6 +77,9 @@ private:
class CreateOplogCommand : public ProjectStoreCommand
{
public:
+ static constexpr char Name[] = "oplog-create";
+ static constexpr char Description[] = "Create a project oplog";
+
CreateOplogCommand();
~CreateOplogCommand();
@@ -75,7 +87,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"oplog-create", "Create oplog in an existing project, the oplog must not already exist."};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_ProjectId;
std::string m_OplogId;
@@ -86,6 +98,9 @@ private:
class ExportOplogCommand : public ProjectStoreCommand
{
public:
+ static constexpr char Name[] = "oplog-export";
+ static constexpr char Description[] = "Export project store oplog";
+
ExportOplogCommand();
~ExportOplogCommand();
@@ -93,8 +108,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"oplog-export",
- "Export project store oplog to cloud (--cloud), file system (--file) or other Zen instance (--zen)"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_ProjectName;
std::string m_OplogName;
@@ -145,6 +159,9 @@ private:
class ImportOplogCommand : public ProjectStoreCommand
{
public:
+ static constexpr char Name[] = "oplog-import";
+ static constexpr char Description[] = "Import project store oplog";
+
ImportOplogCommand();
~ImportOplogCommand();
@@ -152,8 +169,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"oplog-import",
- "Import project store oplog from cloud (--cloud), file system (--file) or other Zen instance (--zen)"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_ProjectName;
std::string m_OplogName;
@@ -193,19 +209,23 @@ private:
bool m_BoostWorkerCount = false;
bool m_BoostWorkerMemory = false;
bool m_BoostWorkers = false;
+
+ std::string m_AllowPartialBlockRequests = "true";
};
class SnapshotOplogCommand : public ProjectStoreCommand
{
public:
+ static constexpr char Name[] = "oplog-snapshot";
+ static constexpr char Description[] = "Snapshot project store oplog";
+
SnapshotOplogCommand();
~SnapshotOplogCommand();
-
virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"oplog-snapshot", "Snapshot external file references in project store oplog into zen"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_ProjectName;
std::string m_OplogName;
@@ -214,26 +234,32 @@ private:
class ProjectStatsCommand : public ProjectStoreCommand
{
public:
+ static constexpr char Name[] = "project-stats";
+ static constexpr char Description[] = "Stats on project store";
+
ProjectStatsCommand();
~ProjectStatsCommand();
virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"project-stats", "Stats info on project store"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
};
class ProjectOpDetailsCommand : public ProjectStoreCommand
{
public:
+ static constexpr char Name[] = "project-op-details";
+ static constexpr char Description[] = "Detail info on ops inside a project store oplog";
+
ProjectOpDetailsCommand();
~ProjectOpDetailsCommand();
virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"project-op-details", "Detail info on ops inside a project store oplog"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
bool m_Details = false;
bool m_OpDetails = false;
@@ -247,13 +273,16 @@ private:
class OplogMirrorCommand : public ProjectStoreCommand
{
public:
+ static constexpr char Name[] = "oplog-mirror";
+ static constexpr char Description[] = "Mirror project store oplog to file system";
+
OplogMirrorCommand();
~OplogMirrorCommand();
virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"oplog-mirror", "Mirror oplog to file system"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_ProjectName;
std::string m_OplogName;
@@ -268,13 +297,16 @@ private:
class OplogValidateCommand : public ProjectStoreCommand
{
public:
+ static constexpr char Name[] = "oplog-validate";
+ static constexpr char Description[] = "Validate oplog for missing references";
+
OplogValidateCommand();
~OplogValidateCommand();
virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"oplog-validate", "Validate oplog for missing references"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_ProjectName;
std::string m_OplogName;
diff --git a/src/zen/cmds/rpcreplay_cmd.h b/src/zen/cmds/rpcreplay_cmd.h
index a6363b614..332a3126c 100644
--- a/src/zen/cmds/rpcreplay_cmd.h
+++ b/src/zen/cmds/rpcreplay_cmd.h
@@ -9,6 +9,9 @@ namespace zen {
class RpcStartRecordingCommand : public CacheStoreCommand
{
public:
+ static constexpr char Name[] = "rpc-record-start";
+ static constexpr char Description[] = "Starts recording of cache rpc requests on a host";
+
RpcStartRecordingCommand();
~RpcStartRecordingCommand();
@@ -16,7 +19,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"rpc-record-start", "Starts recording of cache rpc requests on a host"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_RecordingPath;
};
@@ -24,6 +27,9 @@ private:
class RpcStopRecordingCommand : public CacheStoreCommand
{
public:
+ static constexpr char Name[] = "rpc-record-stop";
+ static constexpr char Description[] = "Stops recording of cache rpc requests on a host";
+
RpcStopRecordingCommand();
~RpcStopRecordingCommand();
@@ -31,13 +37,16 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"rpc-record-stop", "Stops recording of cache rpc requests on a host"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
};
class RpcReplayCommand : public CacheStoreCommand
{
public:
+ static constexpr char Name[] = "rpc-record-replay";
+ static constexpr char Description[] = "Replays a previously recorded session of rpc requests";
+
RpcReplayCommand();
~RpcReplayCommand();
@@ -45,7 +54,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"rpc-record-replay", "Replays a previously recorded session of cache rpc requests to a target host"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_RecordingPath;
bool m_OnHost = false;
diff --git a/src/zen/cmds/run_cmd.h b/src/zen/cmds/run_cmd.h
index 570a2e63a..300c08c5b 100644
--- a/src/zen/cmds/run_cmd.h
+++ b/src/zen/cmds/run_cmd.h
@@ -9,6 +9,9 @@ namespace zen {
class RunCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "run";
+ static constexpr char Description[] = "Run command with special options";
+
RunCommand();
~RunCommand();
@@ -17,7 +20,7 @@ public:
virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; }
private:
- cxxopts::Options m_Options{"run", "Run executable"};
+ cxxopts::Options m_Options{Name, Description};
int m_RunCount = 0;
int m_RunTime = -1;
std::string m_BaseDirectory;
diff --git a/src/zen/cmds/serve_cmd.h b/src/zen/cmds/serve_cmd.h
index ac74981f2..22f430948 100644
--- a/src/zen/cmds/serve_cmd.h
+++ b/src/zen/cmds/serve_cmd.h
@@ -11,6 +11,9 @@ namespace zen {
class ServeCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "serve";
+ static constexpr char Description[] = "Serve files from a directory";
+
ServeCommand();
~ServeCommand();
@@ -18,7 +21,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"serve", "Serve files from a tree"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
std::string m_ProjectName;
std::string m_OplogName;
diff --git a/src/zen/cmds/status_cmd.h b/src/zen/cmds/status_cmd.h
index dc103a196..df5df3066 100644
--- a/src/zen/cmds/status_cmd.h
+++ b/src/zen/cmds/status_cmd.h
@@ -11,6 +11,9 @@ namespace zen {
class StatusCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "status";
+ static constexpr char Description[] = "Show zen status";
+
StatusCommand();
~StatusCommand();
@@ -20,7 +23,7 @@ public:
private:
int GetLockFileEffectivePort() const;
- cxxopts::Options m_Options{"status", "Show zen status"};
+ cxxopts::Options m_Options{Name, Description};
uint16_t m_Port = 0;
std::filesystem::path m_DataDir;
};
diff --git a/src/zen/cmds/top_cmd.h b/src/zen/cmds/top_cmd.h
index 74167ecfd..aeb196558 100644
--- a/src/zen/cmds/top_cmd.h
+++ b/src/zen/cmds/top_cmd.h
@@ -9,6 +9,9 @@ namespace zen {
class TopCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "top";
+ static constexpr char Description[] = "Monitor zen server activity";
+
TopCommand();
~TopCommand();
@@ -16,12 +19,15 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"top", "Show dev UI"};
+ cxxopts::Options m_Options{Name, Description};
};
class PsCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "ps";
+ static constexpr char Description[] = "Enumerate running zen server instances";
+
PsCommand();
~PsCommand();
@@ -29,7 +35,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"ps", "Enumerate running Zen server instances"};
+ cxxopts::Options m_Options{Name, Description};
};
} // namespace zen
diff --git a/src/zen/cmds/trace_cmd.h b/src/zen/cmds/trace_cmd.h
index a6c9742b7..6eb0ba22b 100644
--- a/src/zen/cmds/trace_cmd.h
+++ b/src/zen/cmds/trace_cmd.h
@@ -6,11 +6,12 @@
namespace zen {
-/** Scrub storage
- */
class TraceCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "trace";
+ static constexpr char Description[] = "Control zen realtime tracing";
+
TraceCommand();
~TraceCommand();
@@ -18,7 +19,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"trace", "Control zen realtime tracing"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_HostName;
bool m_Stop = false;
std::string m_TraceHost;
diff --git a/src/zen/cmds/ui_cmd.cpp b/src/zen/cmds/ui_cmd.cpp
new file mode 100644
index 000000000..da06ce305
--- /dev/null
+++ b/src/zen/cmds/ui_cmd.cpp
@@ -0,0 +1,236 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "ui_cmd.h"
+
+#include <zencore/except_fmt.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/process.h>
+#include <zenutil/consoletui.h>
+#include <zenutil/zenserverprocess.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+# include <shellapi.h>
+#endif
+
+namespace zen {
+
+namespace {
+
+ struct RunningServerInfo
+ {
+ uint16_t Port;
+ uint32_t Pid;
+ std::string SessionId;
+ std::string CmdLine;
+ };
+
+ static std::vector<RunningServerInfo> CollectRunningServers()
+ {
+ std::vector<RunningServerInfo> Servers;
+ ZenServerState State;
+ if (!State.InitializeReadOnly())
+ return Servers;
+
+ State.Snapshot([&](const ZenServerState::ZenServerEntry& Entry) {
+ StringBuilder<25> SessionSB;
+ Entry.GetSessionId().ToString(SessionSB);
+ std::error_code CmdLineEc;
+ std::string CmdLine = GetProcessCommandLine(static_cast<int>(Entry.Pid.load()), CmdLineEc);
+ Servers.push_back({Entry.EffectiveListenPort.load(), Entry.Pid.load(), std::string(SessionSB.c_str()), std::move(CmdLine)});
+ });
+
+ return Servers;
+ }
+
+} // namespace
+
+UiCommand::UiCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_options()("a,all", "Open dashboard for all running instances", cxxopts::value(m_All)->default_value("false"));
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>");
+ m_Options.add_option("",
+ "p",
+ "path",
+ "Dashboard path (default: /dashboard/)",
+ cxxopts::value(m_DashboardPath)->default_value("/dashboard/"),
+ "<path>");
+ m_Options.parse_positional("path");
+}
+
+UiCommand::~UiCommand()
+{
+}
+
+void
+UiCommand::OpenBrowser(std::string_view HostName)
+{
+ // Allow shortcuts for specifying dashboard path, and ensure it is in a format we expect
+ // (leading slash, trailing slash if no file extension)
+
+ if (!m_DashboardPath.empty())
+ {
+ if (m_DashboardPath[0] != '/')
+ {
+ m_DashboardPath = "/dashboard/" + m_DashboardPath;
+ }
+
+ if (m_DashboardPath.find_last_of('.') == std::string::npos && m_DashboardPath.back() != '/')
+ {
+ m_DashboardPath += '/';
+ }
+ }
+
+ bool Success = false;
+
+ ExtendableStringBuilder<256> FullUrl;
+ FullUrl << HostName << m_DashboardPath;
+
+#if ZEN_PLATFORM_WINDOWS
+ HINSTANCE Result = ShellExecuteA(nullptr, "open", FullUrl.c_str(), nullptr, nullptr, SW_SHOWNORMAL);
+ Success = reinterpret_cast<intptr_t>(Result) > 32;
+#else
+ // Validate URL doesn't contain shell metacharacters that could lead to command injection
+ std::string_view FullUrlView = FullUrl;
+ constexpr std::string_view DangerousChars = ";|&$`\\\"'<>(){}[]!#*?~\n\r";
+ if (FullUrlView.find_first_of(DangerousChars) != std::string_view::npos)
+ {
+ throw OptionParseException(fmt::format("URL contains invalid characters: '{}'", FullUrl), m_Options.help());
+ }
+
+# if ZEN_PLATFORM_MAC
+ std::string Command = fmt::format("open \"{}\"", FullUrl);
+# elif ZEN_PLATFORM_LINUX
+ std::string Command = fmt::format("xdg-open \"{}\"", FullUrl);
+# else
+ ZEN_NOT_IMPLEMENTED("Browser launching not implemented on this platform");
+# endif
+
+ Success = system(Command.c_str()) == 0;
+#endif
+
+ if (!Success)
+ {
+ throw zen::runtime_error("Failed to launch browser for '{}'", FullUrl);
+ }
+
+ ZEN_CONSOLE("Web browser launched for '{}' successfully", FullUrl);
+}
+
+void
+UiCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ using namespace std::literals;
+
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return;
+ }
+
+ // Resolve target server
+ uint16_t ServerPort = 0;
+
+ if (m_HostName.empty())
+ {
+ // Auto-discover running instances.
+ std::vector<RunningServerInfo> Servers = CollectRunningServers();
+
+ if (m_All)
+ {
+ if (Servers.empty())
+ {
+ throw OptionParseException("No running Zen server instances found", m_Options.help());
+ }
+
+ for (const auto& Server : Servers)
+ {
+ OpenBrowser(fmt::format("http://localhost:{}", Server.Port));
+ }
+ return;
+ }
+
+ // If multiple are found and we have an interactive terminal, present a picker
+ // instead of silently using the first one.
+ if (Servers.size() > 1 && IsTuiAvailable())
+ {
+ std::vector<std::string> Labels;
+ Labels.reserve(Servers.size() + 1);
+ Labels.push_back(fmt::format("(all {} instances)", Servers.size()));
+
+ const int32_t Cols = static_cast<int32_t>(TuiConsoleColumns());
+ constexpr int32_t kIndicator = 3; // " ▶ " or " " prefix
+ constexpr int32_t kSeparator = 2; // " " before cmdline
+ constexpr int32_t kEllipsis = 3; // "..."
+
+ for (const auto& Server : Servers)
+ {
+ std::string Label = fmt::format("port {:<5} pid {:<7} session {}", Server.Port, Server.Pid, Server.SessionId);
+
+ if (!Server.CmdLine.empty())
+ {
+ int32_t Available = Cols - kIndicator - kSeparator - static_cast<int32_t>(Label.size());
+ if (Available > kEllipsis)
+ {
+ Label += " ";
+ if (static_cast<int32_t>(Server.CmdLine.size()) <= Available)
+ {
+ Label += Server.CmdLine;
+ }
+ else
+ {
+ Label.append(Server.CmdLine, 0, static_cast<size_t>(Available - kEllipsis));
+ Label += "...";
+ }
+ }
+ }
+
+ Labels.push_back(std::move(Label));
+ }
+
+ int SelectedIdx = TuiPickOne("Multiple Zen server instances found. Select one to open:", Labels);
+ if (SelectedIdx < 0)
+ return; // User cancelled
+
+ if (SelectedIdx == 0)
+ {
+ // "All" selected
+ for (const auto& Server : Servers)
+ {
+ OpenBrowser(fmt::format("http://localhost:{}", Server.Port));
+ }
+ return;
+ }
+
+ ServerPort = Servers[SelectedIdx - 1].Port;
+ m_HostName = fmt::format("http://localhost:{}", ServerPort);
+ }
+
+ if (m_HostName.empty())
+ {
+ // Single or zero instances, or not an interactive terminal:
+ // fall back to default resolution (picks first instance or returns empty)
+ m_HostName = ResolveTargetHostSpec("", ServerPort);
+ }
+ }
+ else
+ {
+ if (m_All)
+ {
+ throw OptionParseException("--all cannot be used together with --hosturl", m_Options.help());
+ }
+ m_HostName = ResolveTargetHostSpec(m_HostName, ServerPort);
+ }
+
+ if (m_HostName.empty())
+ {
+ throw OptionParseException("Unable to resolve server specification", m_Options.help());
+ }
+
+ OpenBrowser(m_HostName);
+}
+
+} // namespace zen
diff --git a/src/zen/cmds/ui_cmd.h b/src/zen/cmds/ui_cmd.h
new file mode 100644
index 000000000..c74cdbbd0
--- /dev/null
+++ b/src/zen/cmds/ui_cmd.h
@@ -0,0 +1,32 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+#include <filesystem>
+
+namespace zen {
+
+class UiCommand : public ZenCmdBase
+{
+public:
+ UiCommand();
+ ~UiCommand();
+
+ static constexpr char Name[] = "ui";
+ static constexpr char Description[] = "Launch web browser with zen server UI";
+
+ virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ void OpenBrowser(std::string_view HostName);
+
+ cxxopts::Options m_Options{Name, Description};
+ std::string m_HostName;
+ std::string m_DashboardPath = "/dashboard/";
+ bool m_All = false;
+};
+
+} // namespace zen
diff --git a/src/zen/cmds/up_cmd.h b/src/zen/cmds/up_cmd.h
index 2e822d5fc..270db7f88 100644
--- a/src/zen/cmds/up_cmd.h
+++ b/src/zen/cmds/up_cmd.h
@@ -11,6 +11,9 @@ namespace zen {
class UpCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "up";
+ static constexpr char Description[] = "Bring zen server up";
+
UpCommand();
~UpCommand();
@@ -18,7 +21,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"up", "Bring up zen service"};
+ cxxopts::Options m_Options{Name, Description};
uint16_t m_Port = 0;
bool m_ShowConsole = false;
bool m_ShowLog = false;
@@ -28,6 +31,9 @@ private:
class AttachCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "attach";
+ static constexpr char Description[] = "Add a sponsor process to a running zen service";
+
AttachCommand();
~AttachCommand();
@@ -35,7 +41,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"attach", "Add a sponsor process to a running zen service"};
+ cxxopts::Options m_Options{Name, Description};
uint16_t m_Port = 0;
int m_OwnerPid = 0;
std::filesystem::path m_DataDir;
@@ -44,6 +50,9 @@ private:
class DownCommand : public ZenCmdBase
{
public:
+ static constexpr char Name[] = "down";
+ static constexpr char Description[] = "Bring zen server down";
+
DownCommand();
~DownCommand();
@@ -51,7 +60,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"down", "Bring down zen service"};
+ cxxopts::Options m_Options{Name, Description};
uint16_t m_Port = 0;
bool m_ForceTerminate = false;
std::filesystem::path m_ProgramBaseDir;
diff --git a/src/zen/cmds/vfs_cmd.h b/src/zen/cmds/vfs_cmd.h
index 5deaa02fa..9009c774b 100644
--- a/src/zen/cmds/vfs_cmd.h
+++ b/src/zen/cmds/vfs_cmd.h
@@ -9,6 +9,9 @@ namespace zen {
class VfsCommand : public StorageCommand
{
public:
+ static constexpr char Name[] = "vfs";
+ static constexpr char Description[] = "Manage virtual file system";
+
VfsCommand();
~VfsCommand();
@@ -16,7 +19,7 @@ public:
virtual cxxopts::Options& Options() override { return m_Options; }
private:
- cxxopts::Options m_Options{"vfs", "Manage virtual file system"};
+ cxxopts::Options m_Options{Name, Description};
std::string m_Verb;
std::string m_HostName;
diff --git a/src/zen/cmds/wipe_cmd.cpp b/src/zen/cmds/wipe_cmd.cpp
index adf0e61f0..10f5ad8e1 100644
--- a/src/zen/cmds/wipe_cmd.cpp
+++ b/src/zen/cmds/wipe_cmd.cpp
@@ -33,7 +33,7 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
-namespace {
+namespace wipe_impl {
static std::atomic<bool> AbortFlag = false;
static std::atomic<bool> PauseFlag = false;
static bool IsVerbose = false;
@@ -49,10 +49,11 @@ namespace {
: GetMediumWorkerPool(EWorkloadType::Burst);
}
-#define ZEN_CONSOLE_VERBOSE(fmtstr, ...) \
- if (IsVerbose) \
- { \
- ZEN_CONSOLE_LOG(zen::logging::level::Info, fmtstr, ##__VA_ARGS__); \
+#undef ZEN_CONSOLE_VERBOSE
+#define ZEN_CONSOLE_VERBOSE(fmtstr, ...) \
+ if (IsVerbose) \
+ { \
+ ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__); \
}
static void SignalCallbackHandler(int SigNum)
@@ -505,7 +506,7 @@ namespace {
}
return CleanWipe;
}
-} // namespace
+} // namespace wipe_impl
WipeCommand::WipeCommand()
{
@@ -532,6 +533,7 @@ WipeCommand::~WipeCommand() = default;
void
WipeCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
+ using namespace wipe_impl;
ZEN_UNUSED(GlobalOptions);
signal(SIGINT, SignalCallbackHandler);
@@ -549,7 +551,7 @@ WipeCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
ProgressMode = (IsVerbose || m_PlainProgress) ? ProgressBar::Mode::Plain : ProgressBar::Mode::Pretty;
BoostWorkerThreads = m_BoostWorkerThreads;
- MakeSafeAbsolutePathÍnPlace(m_Directory);
+ MakeSafeAbsolutePathInPlace(m_Directory);
if (!IsDir(m_Directory))
{
diff --git a/src/zen/cmds/workspaces_cmd.cpp b/src/zen/cmds/workspaces_cmd.cpp
index 6e6f5d863..af265d898 100644
--- a/src/zen/cmds/workspaces_cmd.cpp
+++ b/src/zen/cmds/workspaces_cmd.cpp
@@ -398,7 +398,7 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char**
}
else
{
- MakeSafeAbsolutePathÍnPlace(m_SystemRootDir);
+ MakeSafeAbsolutePathInPlace(m_SystemRootDir);
}
std::filesystem::path StatePath = m_SystemRootDir / "workspaces";
@@ -815,7 +815,7 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char**
if (Results.size() != m_ChunkIds.size())
{
throw std::runtime_error(
- fmt::format("failed to get workspace share batch - invalid result count recevied (expected: {}, received: {}",
+ fmt::format("failed to get workspace share batch - invalid result count received (expected: {}, received: {})",
m_ChunkIds.size(),
Results.size()));
}
diff --git a/src/zen/progressbar.cpp b/src/zen/progressbar.cpp
index 83606df67..b758c061b 100644
--- a/src/zen/progressbar.cpp
+++ b/src/zen/progressbar.cpp
@@ -8,16 +8,12 @@
#include <zencore/logging.h>
#include <zencore/windows.h>
#include <zenremotestore/operationlogoutput.h>
+#include <zenutil/consoletui.h>
ZEN_THIRD_PARTY_INCLUDES_START
#include <gsl/gsl-lite.hpp>
ZEN_THIRD_PARTY_INCLUDES_END
-#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
-# include <sys/ioctl.h>
-# include <unistd.h>
-#endif
-
//////////////////////////////////////////////////////////////////////////
namespace zen {
@@ -31,35 +27,12 @@ GetConsoleHandle()
}
#endif
-static bool
-CheckStdoutTty()
-{
-#if ZEN_PLATFORM_WINDOWS
- HANDLE hStdOut = GetConsoleHandle();
- DWORD dwMode = 0;
- static bool IsConsole = ::GetConsoleMode(hStdOut, &dwMode);
- return IsConsole;
-#else
- return isatty(fileno(stdout));
-#endif
-}
-
-static bool
-IsStdoutTty()
-{
- static bool StdoutIsTty = CheckStdoutTty();
- return StdoutIsTty;
-}
-
static void
OutputToConsoleRaw(const char* String, size_t Length)
{
#if ZEN_PLATFORM_WINDOWS
HANDLE hStdOut = GetConsoleHandle();
-#endif
-
-#if ZEN_PLATFORM_WINDOWS
- if (IsStdoutTty())
+ if (TuiIsStdoutTty())
{
WriteConsoleA(hStdOut, String, (DWORD)Length, 0, 0);
}
@@ -85,26 +58,6 @@ OutputToConsoleRaw(const StringBuilderBase& SB)
}
uint32_t
-GetConsoleColumns(uint32_t Default)
-{
-#if ZEN_PLATFORM_WINDOWS
- HANDLE hStdOut = GetConsoleHandle();
- CONSOLE_SCREEN_BUFFER_INFO csbi;
- if (GetConsoleScreenBufferInfo(hStdOut, &csbi) == TRUE)
- {
- return (uint32_t)(csbi.srWindow.Right - csbi.srWindow.Left + 1);
- }
-#else
- struct winsize w;
- if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &w) == 0)
- {
- return (uint32_t)w.ws_col;
- }
-#endif
- return Default;
-}
-
-uint32_t
GetUpdateDelayMS(ProgressBar::Mode InMode)
{
switch (InMode)
@@ -165,7 +118,7 @@ ProgressBar::PopLogOperation(Mode InMode)
}
ProgressBar::ProgressBar(Mode InMode, std::string_view InSubTask)
-: m_Mode((!IsStdoutTty() && InMode == Mode::Pretty) ? Mode::Plain : InMode)
+: m_Mode((!TuiIsStdoutTty() && InMode == Mode::Pretty) ? Mode::Plain : InMode)
, m_LastUpdateMS((uint64_t)-1)
, m_PausedMS(0)
, m_SubTask(InSubTask)
@@ -245,6 +198,7 @@ ProgressBar::UpdateState(const State& NewState, bool DoLinebreak)
const std::string Details = (!NewState.Details.empty()) ? fmt::format(": {}", NewState.Details) : "";
const std::string Output = fmt::format("{} {}% ({}){}\n", Task, PercentDone, NiceTimeSpanMs(ElapsedTimeMS), Details);
OutputToConsoleRaw(Output);
+ m_State = NewState;
}
else if (m_Mode == Mode::Pretty)
{
@@ -253,10 +207,11 @@ ProgressBar::UpdateState(const State& NewState, bool DoLinebreak)
size_t ProgressBarCount = (ProgressBarSize * PercentDone) / 100;
uint64_t Completed = NewState.TotalCount - NewState.RemainingCount;
uint64_t ETAElapsedMS = ElapsedTimeMS -= m_PausedMS;
- uint64_t ETAMS =
- (NewState.Status == State::EStatus::Running) && (PercentDone > 5) ? (ETAElapsedMS * NewState.RemainingCount) / Completed : 0;
+ uint64_t ETAMS = ((m_State.TotalCount == NewState.TotalCount) && (NewState.Status == State::EStatus::Running)) && (PercentDone > 5)
+ ? (ETAElapsedMS * NewState.RemainingCount) / Completed
+ : 0;
- uint32_t ConsoleColumns = GetConsoleColumns(1024);
+ uint32_t ConsoleColumns = TuiConsoleColumns(1024);
const std::string PercentString = fmt::format("{:#3}%", PercentDone);
@@ -435,19 +390,19 @@ class ConsoleOpLogOutput : public OperationLogOutput
{
public:
ConsoleOpLogOutput(zen::ProgressBar::Mode InMode) : m_Mode(InMode) {}
- virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args)
+ virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override
{
- logging::EmitConsoleLogMessage(LogLevel, Format, Args);
+ logging::EmitConsoleLogMessage(Point, Args);
}
- virtual void SetLogOperationName(std::string_view Name) { zen::ProgressBar::SetLogOperationName(m_Mode, Name); }
- virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount)
+ virtual void SetLogOperationName(std::string_view Name) override { zen::ProgressBar::SetLogOperationName(m_Mode, Name); }
+ virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override
{
zen::ProgressBar::SetLogOperationProgress(m_Mode, StepIndex, StepCount);
}
- virtual uint32_t GetProgressUpdateDelayMS() { return GetUpdateDelayMS(m_Mode); }
+ virtual uint32_t GetProgressUpdateDelayMS() override { return GetUpdateDelayMS(m_Mode); }
- virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) { return new ConsoleOpLogProgressBar(m_Mode, InSubTask); }
+ virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override { return new ConsoleOpLogProgressBar(m_Mode, InSubTask); }
private:
zen::ProgressBar::Mode m_Mode;
diff --git a/src/zen/progressbar.h b/src/zen/progressbar.h
index bbdb008d4..cb1c7023b 100644
--- a/src/zen/progressbar.h
+++ b/src/zen/progressbar.h
@@ -76,7 +76,6 @@ private:
};
uint32_t GetUpdateDelayMS(ProgressBar::Mode InMode);
-uint32_t GetConsoleColumns(uint32_t Default);
OperationLogOutput* CreateConsoleLogOutput(ProgressBar::Mode InMode);
diff --git a/src/zen/xmake.lua b/src/zen/xmake.lua
index ab094fef3..f889c3296 100644
--- a/src/zen/xmake.lua
+++ b/src/zen/xmake.lua
@@ -6,15 +6,12 @@ target("zen")
add_files("**.cpp")
add_files("zen.cpp", {unity_ignored = true })
add_deps("zencore", "zenhttp", "zenremotestore", "zenstore", "zenutil")
+ add_deps("zencompute", "zennet")
add_deps("cxxopts", "fmt")
add_packages("json11")
add_includedirs(".")
set_symbols("debug")
- if is_mode("release") then
- set_optimize("fastest")
- end
-
if is_plat("windows") then
add_files("zen.rc")
add_ldflags("/subsystem:console,5.02")
diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp
index 09a2e4f91..9a466da2e 100644
--- a/src/zen/zen.cpp
+++ b/src/zen/zen.cpp
@@ -11,6 +11,7 @@
#include "cmds/cache_cmd.h"
#include "cmds/copy_cmd.h"
#include "cmds/dedup_cmd.h"
+#include "cmds/exec_cmd.h"
#include "cmds/info_cmd.h"
#include "cmds/print_cmd.h"
#include "cmds/projectstore_cmd.h"
@@ -21,6 +22,7 @@
#include "cmds/status_cmd.h"
#include "cmds/top_cmd.h"
#include "cmds/trace_cmd.h"
+#include "cmds/ui_cmd.h"
#include "cmds/up_cmd.h"
#include "cmds/version_cmd.h"
#include "cmds/vfs_cmd.h"
@@ -39,7 +41,8 @@
#include <zencore/trace.h>
#include <zencore/windows.h>
#include <zenhttp/httpcommon.h>
-#include <zenutil/environmentoptions.h>
+#include <zenutil/config/environmentoptions.h>
+#include <zenutil/consoletui.h>
#include <zenutil/logging.h>
#include <zenutil/workerpools.h>
#include <zenutil/zenserverprocess.h>
@@ -53,7 +56,6 @@
#include "progressbar.h"
#if ZEN_WITH_TESTS
-# define ZEN_TEST_WITH_RUNNER 1
# include <zencore/testing.h>
#endif
@@ -122,7 +124,7 @@ ZenCmdBase::ParseOptions(int argc, char** argv)
bool
ZenCmdBase::ParseOptions(cxxopts::Options& CmdOptions, int argc, char** argv)
{
- CmdOptions.set_width(GetConsoleColumns(80));
+ CmdOptions.set_width(TuiConsoleColumns(80));
cxxopts::ParseResult Result;
@@ -192,6 +194,84 @@ ZenCmdBase::GetSubCommand(cxxopts::Options&,
return argc;
}
+ZenSubCmdBase::ZenSubCmdBase(std::string_view Name, std::string_view Description)
+: m_SubOptions(std::string(Name), std::string(Description))
+{
+ m_SubOptions.add_options()("h,help", "Print help");
+}
+
+void
+ZenCmdWithSubCommands::AddSubCommand(ZenSubCmdBase& SubCmd)
+{
+ m_SubCommands.push_back(&SubCmd);
+}
+
+bool
+ZenCmdWithSubCommands::OnParentOptionsParsed(const ZenCliOptions& /*GlobalOptions*/)
+{
+ return true;
+}
+
+void
+ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ std::vector<cxxopts::Options*> SubOptionPtrs;
+ SubOptionPtrs.reserve(m_SubCommands.size());
+ for (ZenSubCmdBase* SubCmd : m_SubCommands)
+ {
+ SubOptionPtrs.push_back(&SubCmd->SubOptions());
+ }
+
+ cxxopts::Options* MatchedSubOption = nullptr;
+ std::vector<char*> SubCommandArguments;
+ int ParentArgc = GetSubCommand(Options(), argc, argv, SubOptionPtrs, MatchedSubOption, SubCommandArguments);
+
+ if (!ParseOptions(Options(), ParentArgc, argv))
+ {
+ return;
+ }
+
+ if (MatchedSubOption == nullptr)
+ {
+ ExtendableStringBuilder<128> VerbList;
+ for (bool First = true; ZenSubCmdBase * SubCmd : m_SubCommands)
+ {
+ if (!First)
+ {
+ VerbList.Append(", ");
+ }
+ VerbList.Append(SubCmd->SubOptions().program());
+ First = false;
+ }
+ throw OptionParseException(fmt::format("No subcommand specified. Available subcommands: {}", VerbList.ToView()), Options().help());
+ }
+
+ ZenSubCmdBase* MatchedSubCmd = nullptr;
+ for (ZenSubCmdBase* SubCmd : m_SubCommands)
+ {
+ if (&SubCmd->SubOptions() == MatchedSubOption)
+ {
+ MatchedSubCmd = SubCmd;
+ break;
+ }
+ }
+ ZEN_ASSERT(MatchedSubCmd != nullptr);
+
+ // Parse subcommand args before OnParentOptionsParsed so --help on the subcommand
+ // works without requiring parent options like --hosturl to be populated.
+ if (!ParseOptions(*MatchedSubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data()))
+ {
+ return;
+ }
+
+ if (!OnParentOptionsParsed(GlobalOptions))
+ {
+ return;
+ }
+
+ MatchedSubCmd->Run(GlobalOptions);
+}
+
static ReturnCode
GetReturnCodeFromHttpResult(const HttpClientError& Ex)
{
@@ -316,22 +396,25 @@ main(int argc, char** argv)
}
#endif // ZEN_WITH_TRACE
- AttachCommand AttachCmd;
- BenchCommand BenchCmd;
- BuildsCommand BuildsCmd;
- CacheDetailsCommand CacheDetailsCmd;
- CacheGetCommand CacheGetCmd;
- CacheGenerateCommand CacheGenerateCmd;
- CacheInfoCommand CacheInfoCmd;
- CacheStatsCommand CacheStatsCmd;
- CopyCommand CopyCmd;
- CopyStateCommand CopyStateCmd;
- CreateOplogCommand CreateOplogCmd;
- CreateProjectCommand CreateProjectCmd;
- DedupCommand DedupCmd;
- DownCommand DownCmd;
- DropCommand DropCmd;
- DropProjectCommand ProjectDropCmd;
+ AttachCommand AttachCmd;
+ BenchCommand BenchCmd;
+ BuildsCommand BuildsCmd;
+ CacheDetailsCommand CacheDetailsCmd;
+ CacheGetCommand CacheGetCmd;
+ CacheGenerateCommand CacheGenerateCmd;
+ CacheInfoCommand CacheInfoCmd;
+ CacheStatsCommand CacheStatsCmd;
+ CopyCommand CopyCmd;
+ CopyStateCommand CopyStateCmd;
+ CreateOplogCommand CreateOplogCmd;
+ CreateProjectCommand CreateProjectCmd;
+ DedupCommand DedupCmd;
+ DownCommand DownCmd;
+ DropCommand DropCmd;
+ DropProjectCommand ProjectDropCmd;
+#if ZEN_WITH_COMPUTE_SERVICES
+ ExecCommand ExecCmd;
+#endif // ZEN_WITH_COMPUTE_SERVICES
ExportOplogCommand ExportOplogCmd;
FlushCommand FlushCmd;
GcCommand GcCmd;
@@ -360,6 +443,7 @@ main(int argc, char** argv)
LoggingCommand LoggingCmd;
TopCommand TopCmd;
TraceCommand TraceCmd;
+ UiCommand UiCmd;
UpCommand UpCmd;
VersionCommand VersionCmd;
VfsCommand VfsCmd;
@@ -375,53 +459,57 @@ main(int argc, char** argv)
const char* CmdSummary;
} Commands[] = {
// clang-format off
- {"attach", &AttachCmd, "Add a sponsor process to a running zen service"},
- {"bench", &BenchCmd, "Utility command for benchmarking"},
- {BuildsCommand::Name, &BuildsCmd, BuildsCommand::Description},
- {"cache-details", &CacheDetailsCmd, "Details on cache"},
- {"cache-info", &CacheInfoCmd, "Info on cache, namespace or bucket"},
+ {AttachCommand::Name, &AttachCmd, AttachCommand::Description},
+ {BenchCommand::Name, &BenchCmd, BenchCommand::Description},
+ {BuildsCommand::Name, &BuildsCmd, BuildsCommand::Description},
+ {CacheDetailsCommand::Name, &CacheDetailsCmd, CacheDetailsCommand::Description},
+ {CacheInfoCommand::Name, &CacheInfoCmd, CacheInfoCommand::Description},
{CacheGetCommand::Name, &CacheGetCmd, CacheGetCommand::Description},
{CacheGenerateCommand::Name, &CacheGenerateCmd, CacheGenerateCommand::Description},
- {"cache-stats", &CacheStatsCmd, "Stats on cache"},
- {"copy", &CopyCmd, "Copy file(s)"},
- {"copy-state", &CopyStateCmd, "Copy zen server disk state"},
- {"dedup", &DedupCmd, "Dedup files"},
- {"down", &DownCmd, "Bring zen server down"},
- {"drop", &DropCmd, "Drop cache namespace or bucket"},
- {"gc-status", &GcStatusCmd, "Garbage collect zen storage status check"},
- {"gc-stop", &GcStopCmd, "Request cancel of running garbage collection in zen storage"},
- {"gc", &GcCmd, "Garbage collect zen storage"},
- {"info", &InfoCmd, "Show high level Zen server information"},
- {"jobs", &JobCmd, "Show/cancel zen background jobs"},
- {"logs", &LoggingCmd, "Show/control zen logging"},
- {"oplog-create", &CreateOplogCmd, "Create a project oplog"},
- {"oplog-export", &ExportOplogCmd, "Export project store oplog"},
- {"oplog-import", &ImportOplogCmd, "Import project store oplog"},
- {"oplog-mirror", &OplogMirrorCmd, "Mirror project store oplog to file system"},
- {"oplog-snapshot", &SnapshotOplogCmd, "Snapshot project store oplog"},
+ {CacheStatsCommand::Name, &CacheStatsCmd, CacheStatsCommand::Description},
+ {CopyCommand::Name, &CopyCmd, CopyCommand::Description},
+ {CopyStateCommand::Name, &CopyStateCmd, CopyStateCommand::Description},
+ {DedupCommand::Name, &DedupCmd, DedupCommand::Description},
+ {DownCommand::Name, &DownCmd, DownCommand::Description},
+ {DropCommand::Name, &DropCmd, DropCommand::Description},
+#if ZEN_WITH_COMPUTE_SERVICES
+ {ExecCommand::Name, &ExecCmd, ExecCommand::Description},
+#endif
+ {GcStatusCommand::Name, &GcStatusCmd, GcStatusCommand::Description},
+ {GcStopCommand::Name, &GcStopCmd, GcStopCommand::Description},
+ {GcCommand::Name, &GcCmd, GcCommand::Description},
+ {InfoCommand::Name, &InfoCmd, InfoCommand::Description},
+ {JobCommand::Name, &JobCmd, JobCommand::Description},
+ {LoggingCommand::Name, &LoggingCmd, LoggingCommand::Description},
+ {CreateOplogCommand::Name, &CreateOplogCmd, CreateOplogCommand::Description},
+ {ExportOplogCommand::Name, &ExportOplogCmd, ExportOplogCommand::Description},
+ {ImportOplogCommand::Name, &ImportOplogCmd, ImportOplogCommand::Description},
+ {OplogMirrorCommand::Name, &OplogMirrorCmd, OplogMirrorCommand::Description},
+ {SnapshotOplogCommand::Name, &SnapshotOplogCmd, SnapshotOplogCommand::Description},
{OplogDownloadCommand::Name, &OplogDownload, OplogDownloadCommand::Description},
- {"oplog-validate", &OplogValidateCmd, "Validate oplog for missing references"},
- {"print", &PrintCmd, "Print compact binary object"},
- {"printpackage", &PrintPkgCmd, "Print compact binary package"},
- {"project-create", &CreateProjectCmd, "Create a project"},
- {"project-op-details", &ProjectOpDetailsCmd, "Detail info on ops inside a project store oplog"},
- {"project-drop", &ProjectDropCmd, "Drop project or project oplog"},
- {"project-info", &ProjectInfoCmd, "Info on project or project oplog"},
- {"project-stats", &ProjectStatsCmd, "Stats on project store"},
- {"ps", &PsCmd, "Enumerate running zen server instances"},
- {"rpc-record-replay", &RpcReplayCmd, "Replays a previously recorded session of rpc requests"},
- {"rpc-record-start", &RpcStartRecordingCmd, "Starts recording of cache rpc requests on a host"},
- {"rpc-record-stop", &RpcStopRecordingCmd, "Stops recording of cache rpc requests on a host"},
- {"run", &RunCmd, "Run command with special options"},
- {"scrub", &ScrubCmd, "Scrub zen storage (verify data integrity)"},
- {"serve", &ServeCmd, "Serve files from a directory"},
- {"status", &StatusCmd, "Show zen status"},
- {"top", &TopCmd, "Monitor zen server activity"},
- {"trace", &TraceCmd, "Control zen realtime tracing"},
- {"up", &UpCmd, "Bring zen server up"},
+ {OplogValidateCommand::Name, &OplogValidateCmd, OplogValidateCommand::Description},
+ {PrintCommand::Name, &PrintCmd, PrintCommand::Description},
+ {PrintPackageCommand::Name, &PrintPkgCmd, PrintPackageCommand::Description},
+ {CreateProjectCommand::Name, &CreateProjectCmd, CreateProjectCommand::Description},
+ {ProjectOpDetailsCommand::Name, &ProjectOpDetailsCmd, ProjectOpDetailsCommand::Description},
+ {DropProjectCommand::Name, &ProjectDropCmd, DropProjectCommand::Description},
+ {ProjectInfoCommand::Name, &ProjectInfoCmd, ProjectInfoCommand::Description},
+ {ProjectStatsCommand::Name, &ProjectStatsCmd, ProjectStatsCommand::Description},
+ {PsCommand::Name, &PsCmd, PsCommand::Description},
+ {RpcReplayCommand::Name, &RpcReplayCmd, RpcReplayCommand::Description},
+ {RpcStartRecordingCommand::Name, &RpcStartRecordingCmd, RpcStartRecordingCommand::Description},
+ {RpcStopRecordingCommand::Name, &RpcStopRecordingCmd, RpcStopRecordingCommand::Description},
+ {RunCommand::Name, &RunCmd, RunCommand::Description},
+ {ScrubCommand::Name, &ScrubCmd, ScrubCommand::Description},
+ {ServeCommand::Name, &ServeCmd, ServeCommand::Description},
+ {StatusCommand::Name, &StatusCmd, StatusCommand::Description},
+ {TopCommand::Name, &TopCmd, TopCommand::Description},
+ {TraceCommand::Name, &TraceCmd, TraceCommand::Description},
+ {UiCommand::Name, &UiCmd, UiCommand::Description},
+ {UpCommand::Name, &UpCmd, UpCommand::Description},
{VersionCommand::Name, &VersionCmd, VersionCommand::Description},
- {"vfs", &VfsCmd, "Manage virtual file system"},
- {"flush", &FlushCmd, "Flush storage"},
+ {VfsCommand::Name, &VfsCmd, VfsCommand::Description},
+ {FlushCommand::Name, &FlushCmd, FlushCommand::Description},
{WipeCommand::Name, &WipeCmd, WipeCommand::Description},
{WorkspaceCommand::Name, &WorkspaceCmd, WorkspaceCommand::Description},
{WorkspaceShareCommand::Name, &WorkspaceShareCmd, WorkspaceShareCommand::Description},
@@ -538,6 +626,9 @@ main(int argc, char** argv)
Options.add_options()("corelimit", "Limit concurrency", cxxopts::value(CoreLimit));
+ ZenLoggingCmdLineOptions LoggingCmdLineOptions;
+ LoggingCmdLineOptions.AddCliOptions(Options, GlobalOptions.LoggingConfig);
+
#if ZEN_WITH_TRACE
// We only have this in options for command line help purposes - we parse these argument separately earlier using
// GetTraceOptionsFromCommandline()
@@ -624,8 +715,8 @@ main(int argc, char** argv)
}
LimitHardwareConcurrency(CoreLimit);
-#if ZEN_USE_SENTRY
+#if ZEN_USE_SENTRY
{
EnvironmentOptions EnvOptions;
@@ -671,12 +762,19 @@ main(int argc, char** argv)
}
#endif
- zen::LoggingOptions LogOptions;
- LogOptions.IsDebug = GlobalOptions.IsDebug;
- LogOptions.IsVerbose = GlobalOptions.IsVerbose;
- LogOptions.AllowAsync = false;
+ LoggingCmdLineOptions.ApplyOptions(GlobalOptions.LoggingConfig);
+
+ const LoggingOptions LogOptions = {.IsDebug = GlobalOptions.IsDebug,
+ .IsVerbose = GlobalOptions.IsVerbose,
+ .IsTest = false,
+ .NoConsoleOutput = GlobalOptions.LoggingConfig.NoConsoleOutput,
+ .QuietConsole = GlobalOptions.LoggingConfig.QuietConsole,
+ .AbsLogFile = GlobalOptions.LoggingConfig.AbsLogFile,
+ .LogId = GlobalOptions.LoggingConfig.LogId};
zen::InitializeLogging(LogOptions);
+ ApplyLoggingOptions(Options, GlobalOptions.LoggingConfig);
+
std::set_terminate([]() {
void* Frames[8];
uint32_t FrameCount = GetCallstack(2, 8, Frames);
diff --git a/src/zen/zen.h b/src/zen/zen.h
index 05d1e4ec8..06e5356a6 100644
--- a/src/zen/zen.h
+++ b/src/zen/zen.h
@@ -5,7 +5,8 @@
#include <zencore/except.h>
#include <zencore/timer.h>
#include <zencore/zencore.h>
-#include <zenutil/commandlineoptions.h>
+#include <zenutil/config/commandlineoptions.h>
+#include <zenutil/config/loggingconfig.h>
namespace zen {
@@ -14,6 +15,8 @@ struct ZenCliOptions
bool IsDebug = false;
bool IsVerbose = false;
+ ZenLoggingConfig LoggingConfig;
+
// Arguments after " -- " on command line are passed through and not parsed
std::string PassthroughCommandLine;
std::string PassthroughArgs;
@@ -76,4 +79,41 @@ class CacheStoreCommand : public ZenCmdBase
virtual ZenCmdCategory& CommandCategory() const override { return g_CacheStoreCategory; }
};
+// Base for individual subcommands
+class ZenSubCmdBase
+{
+public:
+ ZenSubCmdBase(std::string_view Name, std::string_view Description);
+ virtual ~ZenSubCmdBase() = default;
+ cxxopts::Options& SubOptions() { return m_SubOptions; }
+ virtual void Run(const ZenCliOptions& GlobalOptions) = 0;
+
+protected:
+ cxxopts::Options m_SubOptions;
+};
+
+// Base for commands that host subcommands - handles all dispatch boilerplate
+class ZenCmdWithSubCommands : public ZenCmdBase
+{
+public:
+ void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) final;
+
+protected:
+ void AddSubCommand(ZenSubCmdBase& SubCmd);
+ virtual bool OnParentOptionsParsed(const ZenCliOptions& GlobalOptions);
+
+private:
+ std::vector<ZenSubCmdBase*> m_SubCommands;
+};
+
+class CacheStoreCmdWithSubCommands : public ZenCmdWithSubCommands
+{
+ ZenCmdCategory& CommandCategory() const override { return g_CacheStoreCategory; }
+};
+
+class StorageCmdWithSubCommands : public ZenCmdWithSubCommands
+{
+ ZenCmdCategory& CommandCategory() const override { return g_StorageCategory; }
+};
+
} // namespace zen
diff --git a/src/zen/zen.rc b/src/zen/zen.rc
index 661d75011..0617681a7 100644
--- a/src/zen/zen.rc
+++ b/src/zen/zen.rc
@@ -7,7 +7,7 @@
LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US
#pragma code_page(1252)
-101 ICON "..\\UnrealEngine.ico"
+101 ICON "..\\zen.ico"
VS_VERSION_INFO VERSIONINFO
FILEVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0
diff --git a/src/zenbase/include/zenbase/refcount.h b/src/zenbase/include/zenbase/refcount.h
index 40ad7bca5..08bc6ae54 100644
--- a/src/zenbase/include/zenbase/refcount.h
+++ b/src/zenbase/include/zenbase/refcount.h
@@ -51,6 +51,9 @@ private:
* NOTE: Unlike RefCounted, this class deletes the derived type when the reference count reaches zero.
* It has no virtual destructor, so it's important that you either don't derive from it further,
* or ensure that the derived class has a virtual destructor.
+ *
+ * This class is useful when you want to avoid adding a vtable to a class just to implement
+ * reference counting.
*/
template<typename T>
diff --git a/src/zencompute-test/xmake.lua b/src/zencompute-test/xmake.lua
new file mode 100644
index 000000000..1207bdefd
--- /dev/null
+++ b/src/zencompute-test/xmake.lua
@@ -0,0 +1,8 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target("zencompute-test")
+ set_kind("binary")
+ set_group("tests")
+ add_headerfiles("**.h")
+ add_files("*.cpp")
+ add_deps("zencompute", "zencore")
diff --git a/src/zencompute-test/zencompute-test.cpp b/src/zencompute-test/zencompute-test.cpp
new file mode 100644
index 000000000..60aaeab1d
--- /dev/null
+++ b/src/zencompute-test/zencompute-test.cpp
@@ -0,0 +1,16 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencompute/zencompute.h>
+#include <zencore/testing.h>
+
+#include <zencore/memory/newdelete.h>
+
+int
+main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
+{
+#if ZEN_WITH_TESTS
+ return zen::testing::RunTestMain(argc, argv, "zencompute-test", zen::zencompute_forcelinktests);
+#else
+ return 0;
+#endif
+}
diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md
new file mode 100644
index 000000000..f5188123f
--- /dev/null
+++ b/src/zencompute/CLAUDE.md
@@ -0,0 +1,232 @@
+# zencompute Module
+
+Lambda-style compute function service. Accepts execution requests from HTTP clients, schedules them across local and remote runners, and tracks results.
+
+## Directory Structure
+
+```
+src/zencompute/
+├── include/zencompute/ # Public headers
+│ ├── computeservice.h # ComputeServiceSession public API
+│ ├── httpcomputeservice.h # HTTP service wrapper
+│ ├── orchestratorservice.h # Worker registry and orchestration
+│ ├── httporchestrator.h # HTTP orchestrator with WebSocket push
+│ ├── recordingreader.h # Recording/replay reader API
+│ ├── cloudmetadata.h # Cloud provider detection (AWS/Azure/GCP)
+│ └── mockimds.h # Test helper for cloud metadata
+├── runners/ # Execution backends
+│ ├── functionrunner.h/.cpp # Abstract base + BaseRunnerGroup/RunnerGroup
+│ ├── localrunner.h/.cpp # LocalProcessRunner (sandbox, monitoring, CPU sampling)
+│ ├── windowsrunner.h/.cpp # Windows AppContainer sandboxing + CreateProcessW
+│ ├── linuxrunner.h/.cpp # Linux user/mount/network namespace isolation
+│ ├── macrunner.h/.cpp # macOS Seatbelt sandboxing
+│ ├── winerunner.h/.cpp # Wine runner for Windows executables on Linux
+│ ├── remotehttprunner.h/.cpp # Remote HTTP submission to other zenserver instances
+│ └── deferreddeleter.h/.cpp # Background deletion of sandbox directories
+├── recording/ # Recording/replay subsystem
+│ ├── actionrecorder.h/.cpp # Write actions+attachments to disk
+│ └── recordingreader.cpp # Read recordings back
+├── timeline/
+│ └── workertimeline.h/.cpp # Per-worker action lifecycle event tracking
+├── testing/
+│ └── mockimds.cpp # Mock IMDS for cloud metadata tests
+├── computeservice.cpp # ComputeServiceSession::Impl (~1700 lines)
+├── httpcomputeservice.cpp # HTTP route registration and handlers (~900 lines)
+├── httporchestrator.cpp # Orchestrator HTTP API + WebSocket push
+├── orchestratorservice.cpp # Worker registry, health probing
+└── cloudmetadata.cpp # IMDS probing, termination monitoring
+```
+
+## Key Classes
+
+### `ComputeServiceSession` (computeservice.h)
+Public API entry point. Uses PIMPL (`struct Impl` in computeservice.cpp). Owns:
+- Two `RunnerGroup`s: `m_LocalRunnerGroup`, `m_RemoteRunnerGroup`
+- Scheduler thread that drains `m_UpdatedActions` and drives state transitions
+- Action maps: `m_PendingActions`, `m_RunningMap`, `m_ResultsMap`
+- Queue map: `m_Queues` (QueueEntry objects)
+- Action history ring: `m_ActionHistory` (bounded deque, default 1000)
+
+**Session states:** Created → Ready → Draining → Paused → Abandoned → Sunset. Both Abandoned and Sunset can be jumped to from any earlier state. Abandoned is used for spot instance termination grace periods — on entry, all pending and running actions are immediately marked as `RunnerAction::State::Abandoned` and running processes are best-effort cancelled. Auto-retry is suppressed while the session is Abandoned. `IsHealthy()` returns false for Abandoned and Sunset.
+
+### `RunnerAction` (runners/functionrunner.h)
+Shared ref-counted struct representing one action through its lifecycle.
+
+**Key fields:**
+- `ActionLsn` — global unique sequence number
+- `QueueId` — 0 for implicit/unqueued actions
+- `Worker` — descriptor + content hash
+- `ActionObj` — CbObject with the action spec
+- `CpuUsagePercent` / `CpuSeconds` — atomics updated by monitor thread
+- `RetryCount` — atomic int tracking how many times the action has been rescheduled
+- `Timestamps[State::_Count]` — timestamp of each state transition
+
+**State machine (forward-only under normal flow, atomic):**
+```
+New → Pending → Submitting → Running → Completed
+ → Failed
+ → Abandoned
+ → Cancelled
+```
+`SetActionState()` rejects non-forward transitions. The one exception is `ResetActionStateToPending()`, which uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling. It clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline.
+
+### `LocalProcessRunner` (runners/localrunner.h)
+Base for all local execution. Platform runners subclass this and override:
+- `SubmitAction()` — fork/exec the worker process
+- `SweepRunningActions()` — poll for process exit (waitpid / WaitForSingleObject)
+- `CancelRunningActions()` — signal all processes during shutdown
+- `SampleProcessCpu(RunningAction&)` — read platform CPU usage (no-op default)
+
+**Infrastructure owned by LocalProcessRunner:**
+- Monitor thread — calls `SweepRunningActions()` then `SampleRunningProcessCpu()` in a loop
+- `m_RunningMap` — `RwLock`-guarded map of `Lsn → RunningAction`
+- `DeferredDirectoryDeleter` — sandbox directories are queued for async deletion
+- `PrepareActionSubmission()` — shared preamble (capacity check, sandbox creation, worker manifesting, input decompression)
+- `ProcessCompletedActions()` — shared post-processing (gather outputs, set state, enqueue deletion)
+
+**CPU sampling:** `SampleRunningProcessCpu()` iterates `m_RunningMap` under shared lock, calls `SampleProcessCpu()` per entry, throttled to every 5 seconds per action. Platform implementations:
+- Linux: `/proc/{pid}/stat` utime+stime jiffies ÷ `_SC_CLK_TCK`
+- Windows: `GetProcessTimes()` in 100ns intervals ÷ 10,000,000
+- macOS: `proc_pidinfo(PROC_PIDTASKINFO)` pti_total_user+system nanoseconds ÷ 1,000,000,000
+
+### `FunctionRunner` / `RunnerGroup` (runners/functionrunner.h)
+Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity.
+
+### `HttpComputeService` (include/zencompute/httpcomputeservice.h)
+Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service.
+
+## Action Lifecycle (End to End)
+
+1. **HTTP POST** → `HttpComputeService` ingests attachments, calls `EnqueueAction()`
+2. **Enqueue** → creates `RunnerAction` (New → Pending), calls `PostUpdate()`
+3. **PostUpdate** → appends to `m_UpdatedActions`, signals scheduler thread event
+4. **Scheduler thread** → drains `m_UpdatedActions`, drives pending actions to runners
+5. **Runner `SubmitAction()`** → Pending → Submitting (on runner's worker pool thread)
+6. **Process launch** → Submitting → Running, added to `m_RunningMap`
+7. **Monitor thread `SweepRunningActions()`** → detects exit, gathers outputs
+8. **`ProcessCompletedActions()`** → Running → Completed/Failed/Abandoned, `PostUpdate()`
+9. **Scheduler thread `HandleActionUpdates()`** — for Failed/Abandoned actions, checks retry limit; if retries remain, calls `ResetActionStateToPending()` which loops back to step 3. Otherwise moves to `m_ResultsMap`, records history, notifies queue.
+10. **Client `GET /jobs/{lsn}`** → returns result from `m_ResultsMap`, schedules retirement
+
+### Action Rescheduling
+
+Actions that fail or are abandoned can be automatically retried or manually rescheduled via the API.
+
+**Automatic retry (scheduler path):** In `HandleActionUpdates()`, when a Failed or Abandoned state is detected, the scheduler checks `RetryCount < GetMaxRetriesForQueue(QueueId)`. If retries remain, the action is removed from active maps and `ResetActionStateToPending()` is called, which re-enters it into the scheduler pipeline. The action keeps its original LSN so clients can continue polling with the same identifier.
+
+**Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure.
+
+**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit.
+
+**Cancelled actions are never retried** — cancellation is an intentional user action, not a transient failure.
+
+## Queue System
+
+Queues group actions from a single client session. A `QueueEntry` (internal) tracks:
+- `State` — `std::atomic<QueueState>` lifecycle state (Active → Draining → Cancelled)
+- `ActiveCount` — pending + running actions (atomic)
+- `CompletedCount / FailedCount / AbandonedCount / CancelledCount` (atomics)
+- `ActiveLsns` — for cancellation lookup (under `m_Lock`)
+- `FinishedLsns` — moved here when actions complete
+- `IdleSince` — used for 15-minute automatic expiry
+- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit
+
+**Queue state machine (`QueueState` enum):**
+```
+Active → Draining → Cancelled
+   \                      ↑
+    ──────────────────────┘
+```
+- **Active** — accepts new work, schedules pending work, finishes running work (initial state)
+- **Draining** — rejects new work, finishes existing work (one-way via CAS from Active; cannot override Cancelled)
+- **Cancelled** — rejects new work, actively cancels in-flight work (reachable from Active or Draining)
+
+Key operations:
+- `CreateQueue(Tag)` → returns `QueueId`
+- `EnqueueActionToQueue(QueueId, ...)` → action's `QueueId` field is set at creation
+- `CancelQueue(QueueId)` → marks all active LSNs for cancellation
+- `DrainQueue(QueueId)` → stops accepting new submissions; existing work finishes naturally (irreversible)
+- `GetQueueCompleted(QueueId)` → CbWriter output of finished results
+- Queue references in HTTP routes accept either a decimal ID or an Oid token (24-hex), resolved by `ResolveQueueRef()`
+
+## HTTP API
+
+All routes registered in `HttpComputeService` constructor. Prefix is configured externally (typically `/compute`).
+
+### Global endpoints
+| Method | Path | Description |
+|--------|------|-------------|
+| POST | `abandon` | Transition session to Abandoned state (409 if invalid) |
+| GET | `jobs/history` | Action history (last N, with timestamps per state) |
+| GET | `jobs/running` | In-flight actions with CPU metrics |
+| GET | `jobs/completed` | Actions with results available |
+| GET/POST/DELETE | `jobs/{lsn}` | GET: result; POST: reschedule failed action; DELETE: retire |
+| POST | `jobs/{worker}` | Submit action for specific worker |
+| POST | `jobs` | Submit action (worker resolved from descriptor) |
+| GET | `workers` | List worker IDs |
+| GET | `workers/all` | All workers with full descriptors |
+| GET/POST | `workers/{worker}` | Get/register worker |
+
+### Queue-scoped endpoints
+Queue ref is capture(1) in all `queues/{queueref}/...` routes.
+
+| Method | Path | Description |
+|--------|------|-------------|
+| GET | `queues` | List queue IDs |
+| POST | `queues` | Create queue |
+| GET/DELETE | `queues/{queueref}` | Status / delete |
+| POST | `queues/{queueref}/drain` | Drain queue (irreversible; rejects new submissions) |
+| GET | `queues/{queueref}/completed` | Queue's completed results |
+| GET | `queues/{queueref}/history` | Queue's action history |
+| GET | `queues/{queueref}/running` | Queue's running actions |
+| POST | `queues/{queueref}/jobs` | Submit to queue |
+| GET/POST | `queues/{queueref}/jobs/{lsn}` | GET: result; POST: reschedule |
+| GET/POST | `queues/{queueref}/workers/...` | Worker endpoints (same as global) |
+
+Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `HandleWorkersAllGet`, `HandleWorkerRequest`) shared by top-level and queue-scoped routes.
+
+## Concurrency Model
+
+**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks:
+1. `m_ResultsLock`
+2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock")
+3. `m_PendingLock`
+4. `m_QueueLock`
+
+**Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`.
+
+**Update decoupling:** Runners call `PostUpdate(RunnerAction*)` rather than directly mutating service state. The scheduler thread batches and deduplicates updates.
+
+**Thread ownership:**
+- Scheduler thread — drives state transitions, owns `m_PendingActions`
+- Monitor thread (per runner) — polls process completion, owns `m_RunningMap` via shared lock
+- Worker pool threads — async submission, brief `SubmitAction()` calls
+- HTTP threads — read-only access to results, queue status
+
+## Sandbox Layout
+
+Each action gets a unique numbered directory under `m_SandboxPath`:
+```
+scratch/{counter}/
+ worker/ ← worker binaries (or bind-mounted on Linux)
+ inputs/ ← decompressed action inputs
+ outputs/ ← written by worker process
+```
+
+On Linux with sandboxing enabled, the process runs in a pivot-rooted namespace with `/usr`, `/lib`, `/etc`, `/worker` bind-mounted read-only and a tmpfs `/dev`.
+
+## Adding a New HTTP Endpoint
+
+1. Register the route in the `HttpComputeService` constructor in `httpcomputeservice.cpp`
+2. If the handler is shared between top-level and a `queues/{queueref}/...` variant, extract it as a private helper method declared in `httpcomputeservice.h`
+3. Queue-scoped routes validate the queue ref with `ResolveQueueRef(HttpReq, Req.GetCapture(1))` which writes an error response and returns 0 on failure
+4. Use `CbObjectWriter` for response bodies; emit via `HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save())`
+5. Conditional fields (e.g., optional CPU metrics): emit inside `if (value > 0.0f)` / `if (value >= 0.0f)` guards to omit absent values rather than emitting sentinel values
+
+## Adding a New Runner Platform
+
+1. Subclass `LocalProcessRunner`, add `h`/`cpp` files in `runners/`
+2. Override `SubmitAction()`, `SweepRunningActions()`, `CancelRunningActions()`, and optionally `CancelAction(int)` and `SampleProcessCpu(RunningAction&)`
+3. `SampleProcessCpu()` must update both `Running.Action->CpuSeconds` (unconditionally from the absolute OS value) and `Running.Action->CpuUsagePercent` (delta-based, only after second sample)
+4. `ProcessHandle` convention: store pid as `reinterpret_cast<void*>(static_cast<intptr_t>(pid))` for consistency with the base class
+5. Register in `ComputeServiceSession::AddLocalRunner()` in `computeservice.cpp`
diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp
new file mode 100644
index 000000000..65bac895f
--- /dev/null
+++ b/src/zencompute/cloudmetadata.cpp
@@ -0,0 +1,1014 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencompute/cloudmetadata.h>
+
+#include <zencore/basicfile.h>
+#include <zencore/filesystem.h>
+#include <zencore/string.h>
+#include <zencore/trace.h>
+#include <zenhttp/httpclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::compute {
+
+// All major cloud providers expose instance metadata at this link-local address.
+// It is only routable from within a cloud VM; on bare-metal the TCP connect will
+// fail, which is how we distinguish cloud from non-cloud environments.
+static constexpr std::string_view kImdsEndpoint = "http://169.254.169.254";
+
+// Short connect timeout so that detection on non-cloud machines is fast. The IMDS
+// is a local service on the hypervisor so 200ms is generous for actual cloud VMs.
+static constexpr auto kImdsTimeout = std::chrono::milliseconds{200};
+
+// Human-readable provider name used in logs and in Describe() output.
+// Unrecognized values (including CloudProvider::None) map to "None".
+std::string_view
+ToString(CloudProvider Provider)
+{
+    if (Provider == CloudProvider::AWS)
+    {
+        return "AWS";
+    }
+    if (Provider == CloudProvider::Azure)
+    {
+        return "Azure";
+    }
+    if (Provider == CloudProvider::GCP)
+    {
+        return "GCP";
+    }
+    return "None";
+}
+
+// Production constructor: targets the real link-local IMDS endpoint.
+// The endpoint-taking overload below exists so tests can substitute a mock.
+CloudMetadata::CloudMetadata(std::filesystem::path DataDir) : CloudMetadata(std::move(DataDir), std::string(kImdsEndpoint))
+{
+}
+
+// Detects the cloud provider synchronously, then (only if one was found)
+// starts the background termination monitor. Because detection completes
+// inside the constructor, m_Info is fully populated before any other thread
+// can observe this object.
+CloudMetadata::CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint)
+: m_Log(logging::Get("cloud"))
+, m_DataDir(std::move(DataDir))
+, m_ImdsEndpoint(std::move(ImdsEndpoint))
+{
+    ZEN_TRACE_CPU("CloudMetadata::CloudMetadata");
+
+    // DataDir holds the negative-cache sentinel files; a creation failure is
+    // deliberately ignored — detection still works, just without caching.
+    std::error_code Ec;
+    std::filesystem::create_directories(m_DataDir, Ec);
+
+    DetectProvider();
+
+    // No point polling for termination notices on bare metal.
+    if (m_Info.Provider != CloudProvider::None)
+    {
+        StartTerminationMonitor();
+    }
+}
+
+// Stops the termination monitor thread: clear the enable flag first, then
+// signal the event so the thread wakes out of its 5-second wait, then join
+// before any members it touches are destroyed.
+CloudMetadata::~CloudMetadata()
+{
+    ZEN_TRACE_CPU("CloudMetadata::~CloudMetadata");
+    m_MonitorEnabled = false;
+    m_MonitorEvent.Set();
+    if (m_MonitorThread.joinable())
+    {
+        m_MonitorThread.join();
+    }
+}
+
+// Thread-safe snapshot of the detected provider (None when not on a cloud VM).
+CloudProvider
+CloudMetadata::GetProvider() const
+{
+    return m_InfoLock.WithSharedLock([&] { return m_Info.Provider; });
+}
+
+// Returns a copy of the full instance metadata taken under the shared lock,
+// so callers can inspect it without holding any lock themselves.
+CloudInstanceInfo
+CloudMetadata::GetInstanceInfo() const
+{
+    return m_InfoLock.WithSharedLock([&] { return m_Info; });
+}
+
+// Lock-free check; the flag only ever transitions false -> true (set via
+// exchange in the poll paths), so a relaxed load is sufficient for pollers.
+bool
+CloudMetadata::IsTerminationPending() const
+{
+    return m_TerminationPending.load(std::memory_order_relaxed);
+}
+
+// Human-readable reason recorded by whichever poll first flagged termination;
+// empty until IsTerminationPending() returns true.
+std::string
+CloudMetadata::GetTerminationReason() const
+{
+    return m_ReasonLock.WithSharedLock([&] { return m_TerminationReason; });
+}
+
+// Appends a "cloud" object to the writer for status/diagnostic output.
+// Emits nothing at all off-cloud, so consumers can treat the field's absence
+// as "not running in a cloud".
+void
+CloudMetadata::Describe(CbWriter& Writer) const
+{
+    ZEN_TRACE_CPU("CloudMetadata::Describe");
+    // Single snapshot so all provider/instance fields are mutually consistent.
+    CloudInstanceInfo Info = GetInstanceInfo();
+
+    if (Info.Provider == CloudProvider::None)
+    {
+        return;
+    }
+
+    Writer.BeginObject("cloud");
+    Writer << "provider" << ToString(Info.Provider);
+    Writer << "instance_id" << Info.InstanceId;
+    Writer << "availability_zone" << Info.AvailabilityZone;
+    Writer << "is_spot" << Info.IsSpot;
+    Writer << "is_autoscaling" << Info.IsAutoscaling;
+    Writer << "termination_pending" << IsTerminationPending();
+
+    // NOTE(review): the flag is read twice; if it flips between the reads the
+    // reason is emitted alongside a stale "false" above. Harmless for display,
+    // but confirm if any consumer requires the pair to be consistent.
+    if (IsTerminationPending())
+    {
+        Writer << "termination_reason" << GetTerminationReason();
+    }
+
+    Writer.EndObject();
+}
+
+// Probes the providers in a fixed order; the first successful probe wins and
+// populates m_Info. Short-circuit evaluation preserves the AWS → Azure → GCP
+// probe order of the individual TryDetect calls.
+void
+CloudMetadata::DetectProvider()
+{
+    ZEN_TRACE_CPU("CloudMetadata::DetectProvider");
+
+    const bool Detected = TryDetectAWS() || TryDetectAzure() || TryDetectGCP();
+
+    if (!Detected)
+    {
+        ZEN_DEBUG("no cloud provider detected");
+    }
+}
+
+// AWS detection uses IMDSv2 which requires a session token obtained via PUT before
+// any GET requests are allowed. This is more secure than IMDSv1 (which allowed
+// unauthenticated GETs) and is the default on modern EC2 instances. The token has
+// a 300-second TTL and is reused for termination polling.
+bool
+CloudMetadata::TryDetectAWS()
+{
+    ZEN_TRACE_CPU("CloudMetadata::TryDetectAWS");
+
+    std::filesystem::path SentinelPath = m_DataDir / ".isNotAWS";
+
+    // Negative cache: an earlier probe concluded "not AWS", so skip the
+    // connect-timeout cost entirely (see the sentinel-file comment below).
+    if (HasSentinelFile(SentinelPath))
+    {
+        ZEN_DEBUG("skipping AWS detection - negative cache hit");
+        return false;
+    }
+
+    ZEN_DEBUG("probing AWS IMDS");
+
+    try
+    {
+        HttpClient ImdsClient(m_ImdsEndpoint,
+                              {.LogCategory = "cloud-aws", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}});
+
+        // IMDSv2: acquire session token. The TTL header is mandatory; we request
+        // 300s which is sufficient for the detection phase. The token is also
+        // stored in m_AwsToken for reuse by the termination polling thread.
+        HttpClient::KeyValueMap TokenHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token-ttl-seconds", "300"});
+        HttpClient::Response TokenResponse = ImdsClient.Put("/latest/api/token", IoBuffer{}, TokenHeaders);
+
+        if (!TokenResponse.IsSuccess())
+        {
+            ZEN_DEBUG("AWS IMDS token request failed ({}), not on AWS", static_cast<int>(TokenResponse.StatusCode));
+            WriteSentinelFile(SentinelPath);
+            return false;
+        }
+
+        m_AwsToken = std::string(TokenResponse.AsText());
+
+        HttpClient::KeyValueMap AuthHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token", m_AwsToken});
+
+        // Individual metadata fields are best-effort: a failed GET simply
+        // leaves the corresponding m_Info field at its default.
+        HttpClient::Response IdResponse = ImdsClient.Get("/latest/meta-data/instance-id", AuthHeaders);
+        if (IdResponse.IsSuccess())
+        {
+            m_Info.InstanceId = std::string(IdResponse.AsText());
+        }
+
+        HttpClient::Response AzResponse = ImdsClient.Get("/latest/meta-data/placement/availability-zone", AuthHeaders);
+        if (AzResponse.IsSuccess())
+        {
+            m_Info.AvailabilityZone = std::string(AzResponse.AsText());
+        }
+
+        // "spot" vs "on-demand" — determines whether the instance can be
+        // reclaimed by AWS with a 2-minute warning
+        HttpClient::Response LifecycleResponse = ImdsClient.Get("/latest/meta-data/instance-life-cycle", AuthHeaders);
+        if (LifecycleResponse.IsSuccess())
+        {
+            m_Info.IsSpot = (LifecycleResponse.AsText() == "spot");
+        }
+
+        // This endpoint only exists on instances managed by an Auto Scaling
+        // Group. A successful response (regardless of value) means autoscaling.
+        HttpClient::Response AutoscaleResponse = ImdsClient.Get("/latest/meta-data/autoscaling/target-lifecycle-state", AuthHeaders);
+        if (AutoscaleResponse.IsSuccess())
+        {
+            m_Info.IsAutoscaling = true;
+        }
+
+        // Writing m_Info without m_InfoLock is safe here: detection runs from
+        // the constructor before the monitor thread (or any reader) exists.
+        m_Info.Provider = CloudProvider::AWS;
+
+        ZEN_INFO("detected AWS instance: id={}, az={}, spot={}, autoscaling={}",
+                 m_Info.InstanceId,
+                 m_Info.AvailabilityZone,
+                 m_Info.IsSpot,
+                 m_Info.IsAutoscaling);
+
+        return true;
+    }
+    catch (const std::exception& Ex)
+    {
+        // NOTE(review): a transient transport error on a genuine EC2 host also
+        // writes the sentinel, permanently caching "not AWS" until the marker
+        // is cleared — confirm this trade-off is intended.
+        ZEN_DEBUG("AWS IMDS probe failed: {}", Ex.what());
+        WriteSentinelFile(SentinelPath);
+        return false;
+    }
+}
+
+// Azure IMDS returns a single JSON document for the entire instance metadata,
+// unlike AWS and GCP which use separate plain-text endpoints per field. The
+// "Metadata: true" header is required; requests without it are rejected.
+// The api-version parameter is mandatory and pins the response schema.
+bool
+CloudMetadata::TryDetectAzure()
+{
+    ZEN_TRACE_CPU("CloudMetadata::TryDetectAzure");
+
+    std::filesystem::path SentinelPath = m_DataDir / ".isNotAzure";
+
+    // Negative cache: skip the probe if a previous attempt concluded "not Azure".
+    if (HasSentinelFile(SentinelPath))
+    {
+        ZEN_DEBUG("skipping Azure detection - negative cache hit");
+        return false;
+    }
+
+    ZEN_DEBUG("probing Azure IMDS");
+
+    try
+    {
+        HttpClient ImdsClient(m_ImdsEndpoint,
+                              {.LogCategory = "cloud-azure", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}});
+
+        HttpClient::KeyValueMap MetadataHeaders({
+            std::pair<std::string_view, std::string_view>{"Metadata", "true"},
+        });
+
+        HttpClient::Response InstanceResponse = ImdsClient.Get("/metadata/instance?api-version=2021-02-01", MetadataHeaders);
+
+        if (!InstanceResponse.IsSuccess())
+        {
+            ZEN_DEBUG("Azure IMDS request failed ({}), not on Azure", static_cast<int>(InstanceResponse.StatusCode));
+            WriteSentinelFile(SentinelPath);
+            return false;
+        }
+
+        // A 200 with a non-JSON body means something other than the real IMDS
+        // answered; treat it the same as a failed probe.
+        std::string JsonError;
+        const json11::Json Json = json11::Json::parse(std::string(InstanceResponse.AsText()), JsonError);
+
+        if (!JsonError.empty())
+        {
+            ZEN_DEBUG("Azure IMDS returned invalid JSON: {}", JsonError);
+            WriteSentinelFile(SentinelPath);
+            return false;
+        }
+
+        const json11::Json& Compute = Json["compute"];
+
+        // json11 yields empty strings for absent fields, so these assignments
+        // are best-effort, mirroring the per-field GETs on AWS/GCP.
+        m_Info.InstanceId = Compute["vmId"].string_value();
+        m_Info.AvailabilityZone = Compute["location"].string_value();
+
+        // Azure spot VMs have priority "Spot"; regular VMs have "Regular"
+        std::string Priority = Compute["priority"].string_value();
+        m_Info.IsSpot = (Priority == "Spot");
+
+        // Check if part of a VMSS (Virtual Machine Scale Set) — indicates autoscaling
+        std::string VmssName = Compute["vmScaleSetName"].string_value();
+        m_Info.IsAutoscaling = !VmssName.empty();
+
+        m_Info.Provider = CloudProvider::Azure;
+
+        ZEN_INFO("detected Azure instance: id={}, location={}, spot={}, vmss={}",
+                 m_Info.InstanceId,
+                 m_Info.AvailabilityZone,
+                 m_Info.IsSpot,
+                 m_Info.IsAutoscaling);
+
+        return true;
+    }
+    catch (const std::exception& Ex)
+    {
+        // Transport failure: cache the negative result (see WriteSentinelFile).
+        ZEN_DEBUG("Azure IMDS probe failed: {}", Ex.what());
+        WriteSentinelFile(SentinelPath);
+        return false;
+    }
+}
+
+// GCP requires the "Metadata-Flavor: Google" header on all IMDS requests.
+// Unlike AWS, there is no session token; the header itself is the auth mechanism
+// (it prevents SSRF attacks since browsers won't send custom headers to the
+// metadata endpoint). Each metadata field is fetched from a separate URL.
+bool
+CloudMetadata::TryDetectGCP()
+{
+    ZEN_TRACE_CPU("CloudMetadata::TryDetectGCP");
+
+    std::filesystem::path SentinelPath = m_DataDir / ".isNotGCP";
+
+    // Negative cache: skip the probe if a previous attempt concluded "not GCP".
+    if (HasSentinelFile(SentinelPath))
+    {
+        ZEN_DEBUG("skipping GCP detection - negative cache hit");
+        return false;
+    }
+
+    ZEN_DEBUG("probing GCP metadata service");
+
+    try
+    {
+        HttpClient ImdsClient(m_ImdsEndpoint,
+                              {.LogCategory = "cloud-gcp", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}});
+
+        HttpClient::KeyValueMap MetadataHeaders(std::pair<std::string_view, std::string_view>{"Metadata-Flavor", "Google"});
+
+        // Fetch instance ID — this first request doubles as the
+        // "are we on GCP at all" check.
+        HttpClient::Response IdResponse = ImdsClient.Get("/computeMetadata/v1/instance/id", MetadataHeaders);
+
+        if (!IdResponse.IsSuccess())
+        {
+            ZEN_DEBUG("GCP metadata request failed ({}), not on GCP", static_cast<int>(IdResponse.StatusCode));
+            WriteSentinelFile(SentinelPath);
+            return false;
+        }
+
+        m_Info.InstanceId = std::string(IdResponse.AsText());
+
+        // GCP returns the fully-qualified zone path "projects/<num>/zones/<zone>".
+        // Strip the prefix to get just the zone name (e.g. "us-central1-a").
+        HttpClient::Response ZoneResponse = ImdsClient.Get("/computeMetadata/v1/instance/zone", MetadataHeaders);
+        if (ZoneResponse.IsSuccess())
+        {
+            std::string_view Zone = ZoneResponse.AsText();
+            if (auto Pos = Zone.rfind('/'); Pos != std::string_view::npos)
+            {
+                Zone = Zone.substr(Pos + 1);
+            }
+            m_Info.AvailabilityZone = std::string(Zone);
+        }
+
+        // Check for preemptible/spot (scheduling/preemptible returns "TRUE" or "FALSE")
+        HttpClient::Response PreemptibleResponse = ImdsClient.Get("/computeMetadata/v1/instance/scheduling/preemptible", MetadataHeaders);
+        if (PreemptibleResponse.IsSuccess())
+        {
+            m_Info.IsSpot = (PreemptibleResponse.AsText() == "TRUE");
+        }
+
+        // Check for maintenance event — unlike the other providers, a pending
+        // event can already be flagged during detection, not only by the poller.
+        HttpClient::Response MaintenanceResponse = ImdsClient.Get("/computeMetadata/v1/instance/maintenance-event", MetadataHeaders);
+        if (MaintenanceResponse.IsSuccess())
+        {
+            std::string_view Event = MaintenanceResponse.AsText();
+            if (!Event.empty() && Event != "NONE")
+            {
+                m_TerminationPending = true;
+                m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("GCP maintenance event: {}", Event); });
+            }
+        }
+
+        m_Info.Provider = CloudProvider::GCP;
+
+        ZEN_INFO("detected GCP instance: id={}, az={}, spot={}", m_Info.InstanceId, m_Info.AvailabilityZone, m_Info.IsSpot);
+
+        return true;
+    }
+    catch (const std::exception& Ex)
+    {
+        // Transport failure: cache the negative result (see WriteSentinelFile).
+        ZEN_DEBUG("GCP metadata probe failed: {}", Ex.what());
+        WriteSentinelFile(SentinelPath);
+        return false;
+    }
+}
+
+// Sentinel files are empty marker files whose mere existence signals that a
+// previous detection attempt for a given provider failed. This avoids paying
+// the connect-timeout cost on every startup for providers that are known to
+// be absent. The files persist across process restarts; delete them manually
+// (or remove the DataDir) to force re-detection.
+void
+CloudMetadata::WriteSentinelFile(const std::filesystem::path& Path)
+{
+    try
+    {
+        // kTruncate creates (or empties) the file; only its existence matters.
+        BasicFile File;
+        File.Open(Path, BasicFile::Mode::kTruncate);
+    }
+    catch (const std::exception& Ex)
+    {
+        // Best effort: failing to cache the negative result only costs a
+        // repeat probe on the next startup.
+        ZEN_WARN("failed to write sentinel file '{}': {}", Path.string(), Ex.what());
+    }
+}
+
+// True when a negative-cache marker from a previous failed probe exists.
+bool
+CloudMetadata::HasSentinelFile(const std::filesystem::path& Path) const
+{
+    return zen::IsFile(Path);
+}
+
+// Drops every negative-cache marker so the next startup re-probes all
+// providers. Removal failures are deliberately ignored.
+void
+CloudMetadata::ClearSentinelFiles()
+{
+    static constexpr std::string_view kSentinelNames[] = {".isNotAWS", ".isNotAzure", ".isNotGCP"};
+
+    for (std::string_view Name : kSentinelNames)
+    {
+        std::error_code Ec;
+        std::filesystem::remove(m_DataDir / Name, Ec);
+    }
+}
+
+// Launches the background thread that polls the provider's metadata service
+// for termination/interruption notices. Called from the constructor only when
+// a provider was detected.
+void
+CloudMetadata::StartTerminationMonitor()
+{
+    ZEN_INFO("starting cloud termination monitor for {} instance {}", ToString(m_Info.Provider), m_Info.InstanceId);
+
+    m_MonitorThread = std::thread{&CloudMetadata::TerminationMonitorThread, this};
+}
+
+// Body of the monitor thread started by StartTerminationMonitor().
+void
+CloudMetadata::TerminationMonitorThread()
+{
+    SetCurrentThreadName("cloud_term_mon");
+
+    // Poll every 5 seconds. The Event is used as an interruptible sleep so
+    // that the destructor can wake us up immediately for a clean shutdown.
+    while (m_MonitorEnabled)
+    {
+        m_MonitorEvent.Wait(5000);
+        m_MonitorEvent.Reset();
+
+        // Re-check after waking: the destructor clears the flag then signals.
+        if (!m_MonitorEnabled)
+        {
+            return;
+        }
+
+        PollTermination();
+    }
+}
+
+// Dispatches one termination poll to the provider-specific implementation.
+// Exceptions are swallowed (logged at debug) so a transient IMDS failure
+// never kills the monitor thread — the next 5-second poll simply retries.
+void
+CloudMetadata::PollTermination()
+{
+    try
+    {
+        switch (m_InfoLock.WithSharedLock([&] { return m_Info.Provider; }))
+        {
+        case CloudProvider::AWS:
+            PollAWSTermination();
+            break;
+        case CloudProvider::Azure:
+            PollAzureTermination();
+            break;
+        case CloudProvider::GCP:
+            PollGCPTermination();
+            break;
+        default:
+            break;
+        }
+    }
+    catch (const std::exception& Ex)
+    {
+        ZEN_DEBUG("termination poll error: {}", Ex.what());
+    }
+}
+
+// AWS termination signals:
+//   - /spot/instance-action: returns 200 with a JSON body ~2 minutes before
+//     a spot instance is reclaimed. Returns 404 when no action is pending.
+//   - /autoscaling/target-lifecycle-state: returns the ASG lifecycle state.
+//     "InService" is normal; anything else (e.g. "Terminated:Wait") means
+//     the instance is being cycled out.
+//
+// The IMDSv2 session token acquired during detection only has a 300-second
+// TTL, but this poller runs for the lifetime of the instance. When IMDS
+// rejects the cached token with 401 Unauthorized, a fresh token is acquired
+// and the request retried once — otherwise spot interruptions would silently
+// stop being detected roughly five minutes after startup.
+void
+CloudMetadata::PollAWSTermination()
+{
+    ZEN_TRACE_CPU("CloudMetadata::PollAWSTermination");
+
+    HttpClient ImdsClient(m_ImdsEndpoint,
+                          {.LogCategory = "cloud-aws", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}});
+
+    // Authenticated metadata GET with transparent token refresh on expiry.
+    // m_AwsToken is only touched from this monitor thread after construction,
+    // so no locking is needed here.
+    auto AuthenticatedGet = [&](std::string_view Uri) {
+        HttpClient::KeyValueMap AuthHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token", m_AwsToken});
+        HttpClient::Response Response = ImdsClient.Get(Uri, AuthHeaders);
+
+        if (static_cast<int>(Response.StatusCode) == 401)
+        {
+            // Cached token expired — request a new one with the same 300s TTL
+            // used during detection, then retry the original request once.
+            HttpClient::KeyValueMap TokenHeaders(
+                std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token-ttl-seconds", "300"});
+            HttpClient::Response TokenResponse = ImdsClient.Put("/latest/api/token", IoBuffer{}, TokenHeaders);
+
+            if (TokenResponse.IsSuccess())
+            {
+                m_AwsToken = std::string(TokenResponse.AsText());
+                HttpClient::KeyValueMap RetryHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token", m_AwsToken});
+                Response = ImdsClient.Get(Uri, RetryHeaders);
+            }
+        }
+
+        return Response;
+    };
+
+    HttpClient::Response SpotResponse = AuthenticatedGet("/latest/meta-data/spot/instance-action");
+    if (SpotResponse.IsSuccess())
+    {
+        // exchange() latches the flag so the reason is recorded and logged once.
+        if (!m_TerminationPending.exchange(true))
+        {
+            m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("AWS spot interruption: {}", SpotResponse.AsText()); });
+            ZEN_WARN("AWS spot interruption detected: {}", SpotResponse.AsText());
+        }
+        return;
+    }
+
+    HttpClient::Response AutoscaleResponse = AuthenticatedGet("/latest/meta-data/autoscaling/target-lifecycle-state");
+    if (AutoscaleResponse.IsSuccess())
+    {
+        std::string_view State = AutoscaleResponse.AsText();
+        if (State.find("InService") == std::string_view::npos)
+        {
+            if (!m_TerminationPending.exchange(true))
+            {
+                m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("AWS autoscaling lifecycle: {}", State); });
+                ZEN_WARN("AWS autoscaling termination detected: {}", State);
+            }
+        }
+    }
+}
+
+// Azure Scheduled Events API returns a JSON array of upcoming platform events.
+// We care about "Preempt" (spot eviction), "Terminate", and "Reboot" events.
+// Other event types like "Freeze" (live migration) are non-destructive and
+// ignored. The Events array is empty when nothing is pending.
+void
+CloudMetadata::PollAzureTermination()
+{
+    ZEN_TRACE_CPU("CloudMetadata::PollAzureTermination");
+
+    HttpClient ImdsClient(m_ImdsEndpoint,
+                          {.LogCategory = "cloud-azure", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}});
+
+    HttpClient::KeyValueMap MetadataHeaders({
+        std::pair<std::string_view, std::string_view>{"Metadata", "true"},
+    });
+
+    HttpClient::Response EventsResponse = ImdsClient.Get("/metadata/scheduledevents?api-version=2020-07-01", MetadataHeaders);
+
+    // Errors and malformed payloads are ignored; the next poll retries.
+    if (!EventsResponse.IsSuccess())
+    {
+        return;
+    }
+
+    std::string JsonError;
+    const json11::Json Json = json11::Json::parse(std::string(EventsResponse.AsText()), JsonError);
+
+    if (!JsonError.empty())
+    {
+        return;
+    }
+
+    const json11::Json::array& Events = Json["Events"].array_items();
+    for (const auto& Evt : Events)
+    {
+        std::string EventType = Evt["EventType"].string_value();
+        if (EventType == "Preempt" || EventType == "Terminate" || EventType == "Reboot")
+        {
+            // Latch on the first destructive event; record/log the reason once.
+            if (!m_TerminationPending.exchange(true))
+            {
+                std::string EventStatus = Evt["EventStatus"].string_value();
+                m_ReasonLock.WithExclusiveLock(
+                    [&] { m_TerminationReason = fmt::format("Azure scheduled event: {} ({})", EventType, EventStatus); });
+                ZEN_WARN("Azure termination event detected: {} ({})", EventType, EventStatus);
+            }
+            return;
+        }
+    }
+}
+
+// GCP maintenance-event returns "NONE" when nothing is pending, and a
+// descriptive string like "TERMINATE_ON_HOST_MAINTENANCE" when the VM is
+// about to be live-migrated or terminated. Preemptible/spot VMs get a
+// 30-second warning before termination.
+void
+CloudMetadata::PollGCPTermination()
+{
+    ZEN_TRACE_CPU("CloudMetadata::PollGCPTermination");
+
+    HttpClient ImdsClient(m_ImdsEndpoint,
+                          {.LogCategory = "cloud-gcp", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}});
+
+    // GCP metadata server requires the "Metadata-Flavor: Google" header.
+    HttpClient::KeyValueMap MetadataHeaders(std::pair<std::string_view, std::string_view>{"Metadata-Flavor", "Google"});
+
+    HttpClient::Response MaintenanceResponse = ImdsClient.Get("/computeMetadata/v1/instance/maintenance-event", MetadataHeaders);
+    // Best-effort: a failed request is silently ignored until the next poll.
+    if (MaintenanceResponse.IsSuccess())
+    {
+        std::string_view Event = MaintenanceResponse.AsText();
+        if (!Event.empty() && Event != "NONE")
+        {
+            // Only the first detection (exchange returns prior value) records
+            // the reason and emits the warning.
+            if (!m_TerminationPending.exchange(true))
+            {
+                m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("GCP maintenance event: {}", Event); });
+                ZEN_WARN("GCP maintenance event detected: {}", Event);
+            }
+        }
+    }
+}
+
+} // namespace zen::compute
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+# include <zencompute/mockimds.h>
+
+# include <zencore/filesystem.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+# include <zenhttp/httpserver.h>
+
+# include <memory>
+# include <thread>
+
+namespace zen::compute {
+
+TEST_SUITE_BEGIN("compute.cloudmetadata");
+
+// ---------------------------------------------------------------------------
+// Test helper — spins up a local ASIO HTTP server hosting a MockImdsService
+// ---------------------------------------------------------------------------
+
+struct TestImdsServer
+{
+    // Mock provider state. Configure fields (and ActiveProvider) BEFORE
+    // calling Start()/CreateCloud() so detection sees the intended values.
+    MockImdsService Mock;
+
+    void Start()
+    {
+        m_TmpDir.emplace();
+        m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"});
+        // NOTE(review): 7575 looks like a port hint; Initialize presumably
+        // returns the actual bound port (-1 on failure) — confirm against
+        // HttpServer::Initialize.
+        m_Port = m_Server->Initialize(7575, m_TmpDir->Path() / "http");
+        REQUIRE(m_Port != -1);
+        m_Server->RegisterService(Mock);
+        // Run the server loop on a background thread so the test body can act
+        // as the client.
+        m_ServerThread = std::thread([this]() { m_Server->Run(false); });
+    }
+
+    // Base URI the CloudMetadata instance should use instead of the real IMDS.
+    std::string Endpoint() const { return fmt::format("http://127.0.0.1:{}", m_Port); }
+
+    // Directory where CloudMetadata keeps its sentinel/cache files.
+    std::filesystem::path DataDir() const { return m_TmpDir->Path() / "cloud"; }
+
+    std::unique_ptr<CloudMetadata> CreateCloud() { return std::make_unique<CloudMetadata>(DataDir(), Endpoint()); }
+
+    ~TestImdsServer()
+    {
+        // Ask the server loop to exit first so the thread's Run(false) call
+        // returns, then join, then tear the server down.
+        if (m_Server)
+        {
+            m_Server->RequestExit();
+        }
+        if (m_ServerThread.joinable())
+        {
+            m_ServerThread.join();
+        }
+        if (m_Server)
+        {
+            m_Server->Close();
+        }
+    }
+
+private:
+    std::optional<ScopedTemporaryDirectory> m_TmpDir;
+    Ref<HttpServer> m_Server;
+    std::thread m_ServerThread;
+    int m_Port = -1;
+};
+
+// ---------------------------------------------------------------------------
+// AWS
+// ---------------------------------------------------------------------------
+
+TEST_CASE("cloudmetadata.aws")
+{
+    // Each SUBCASE re-runs the whole TEST_CASE body with a fresh server, so
+    // mock state configured inside a SUBCASE (before Start) is per-scenario.
+    TestImdsServer Imds;
+    Imds.Mock.ActiveProvider = CloudProvider::AWS;
+
+    SUBCASE("detection basics")
+    {
+        Imds.Mock.Aws.InstanceId = "i-abc123";
+        Imds.Mock.Aws.AvailabilityZone = "us-west-2b";
+        Imds.Mock.Aws.LifeCycle = "on-demand";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+
+        CHECK(Cloud->GetProvider() == CloudProvider::AWS);
+
+        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+        CHECK(Info.InstanceId == "i-abc123");
+        CHECK(Info.AvailabilityZone == "us-west-2b");
+        CHECK(Info.IsSpot == false);
+        CHECK(Info.IsAutoscaling == false);
+        CHECK(Cloud->IsTerminationPending() == false);
+    }
+
+    SUBCASE("spot instance")
+    {
+        Imds.Mock.Aws.LifeCycle = "spot";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+        CHECK(Info.IsSpot == true);
+    }
+
+    SUBCASE("autoscaling instance")
+    {
+        Imds.Mock.Aws.AutoscalingState = "InService";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+        CHECK(Info.IsAutoscaling == true);
+    }
+
+    SUBCASE("spot termination")
+    {
+        Imds.Mock.Aws.LifeCycle = "spot";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        CHECK(Cloud->IsTerminationPending() == false);
+
+        // Simulate a spot interruption notice appearing
+        Imds.Mock.Aws.SpotAction = R"({"action":"terminate","time":"2025-01-01T00:00:00Z"})";
+        Cloud->PollTermination();
+
+        CHECK(Cloud->IsTerminationPending() == true);
+        CHECK(Cloud->GetTerminationReason().find("spot interruption") != std::string::npos);
+    }
+
+    SUBCASE("autoscaling termination")
+    {
+        Imds.Mock.Aws.AutoscalingState = "InService";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        CHECK(Cloud->IsTerminationPending() == false);
+
+        // Simulate ASG cycling the instance out
+        // (any state other than "InService" is treated as termination)
+        Imds.Mock.Aws.AutoscalingState = "Terminated:Wait";
+        Cloud->PollTermination();
+
+        CHECK(Cloud->IsTerminationPending() == true);
+        CHECK(Cloud->GetTerminationReason().find("autoscaling") != std::string::npos);
+    }
+
+    SUBCASE("no termination when InService")
+    {
+        Imds.Mock.Aws.AutoscalingState = "InService";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        Cloud->PollTermination();
+
+        CHECK(Cloud->IsTerminationPending() == false);
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Azure
+// ---------------------------------------------------------------------------
+
+TEST_CASE("cloudmetadata.azure")
+{
+    TestImdsServer Imds;
+    Imds.Mock.ActiveProvider = CloudProvider::Azure;
+
+    SUBCASE("detection basics")
+    {
+        Imds.Mock.Azure.VmId = "vm-test-1234";
+        Imds.Mock.Azure.Location = "westeurope";
+        Imds.Mock.Azure.Priority = "Regular";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+
+        CHECK(Cloud->GetProvider() == CloudProvider::Azure);
+
+        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+        CHECK(Info.InstanceId == "vm-test-1234");
+        CHECK(Info.AvailabilityZone == "westeurope");
+        CHECK(Info.IsSpot == false);
+        CHECK(Info.IsAutoscaling == false);
+        CHECK(Cloud->IsTerminationPending() == false);
+    }
+
+    SUBCASE("spot instance")
+    {
+        // Azure signals spot capacity via Priority == "Spot"
+        Imds.Mock.Azure.Priority = "Spot";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+        CHECK(Info.IsSpot == true);
+    }
+
+    SUBCASE("vmss instance")
+    {
+        // Membership in a VM scale set maps to IsAutoscaling
+        Imds.Mock.Azure.VmScaleSetName = "my-vmss";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+        CHECK(Info.IsAutoscaling == true);
+    }
+
+    SUBCASE("preempt termination")
+    {
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        CHECK(Cloud->IsTerminationPending() == false);
+
+        Imds.Mock.Azure.ScheduledEventType = "Preempt";
+        Imds.Mock.Azure.ScheduledEventStatus = "Scheduled";
+        Cloud->PollTermination();
+
+        CHECK(Cloud->IsTerminationPending() == true);
+        CHECK(Cloud->GetTerminationReason().find("Preempt") != std::string::npos);
+    }
+
+    SUBCASE("terminate event")
+    {
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        CHECK(Cloud->IsTerminationPending() == false);
+
+        Imds.Mock.Azure.ScheduledEventType = "Terminate";
+        Cloud->PollTermination();
+
+        CHECK(Cloud->IsTerminationPending() == true);
+        CHECK(Cloud->GetTerminationReason().find("Terminate") != std::string::npos);
+    }
+
+    SUBCASE("no termination when events empty")
+    {
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        Cloud->PollTermination();
+
+        CHECK(Cloud->IsTerminationPending() == false);
+    }
+}
+
+// ---------------------------------------------------------------------------
+// GCP
+// ---------------------------------------------------------------------------
+
+TEST_CASE("cloudmetadata.gcp")
+{
+    TestImdsServer Imds;
+    Imds.Mock.ActiveProvider = CloudProvider::GCP;
+
+    SUBCASE("detection basics")
+    {
+        Imds.Mock.Gcp.InstanceId = "9876543210";
+        Imds.Mock.Gcp.Zone = "projects/123/zones/europe-west1-b";
+        Imds.Mock.Gcp.Preemptible = "FALSE";
+        Imds.Mock.Gcp.MaintenanceEvent = "NONE";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+
+        CHECK(Cloud->GetProvider() == CloudProvider::GCP);
+
+        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+        CHECK(Info.InstanceId == "9876543210");
+        CHECK(Info.AvailabilityZone == "europe-west1-b"); // zone prefix stripped
+        CHECK(Info.IsSpot == false);
+        CHECK(Cloud->IsTerminationPending() == false);
+    }
+
+    SUBCASE("preemptible instance")
+    {
+        // GCP reports preemptibility as the string "TRUE"/"FALSE"
+        Imds.Mock.Gcp.Preemptible = "TRUE";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+        CHECK(Info.IsSpot == true);
+    }
+
+    SUBCASE("maintenance event during detection")
+    {
+        Imds.Mock.Gcp.MaintenanceEvent = "TERMINATE_ON_HOST_MAINTENANCE";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+
+        // GCP sets termination pending immediately during detection if a
+        // maintenance event is active
+        CHECK(Cloud->IsTerminationPending() == true);
+        CHECK(Cloud->GetTerminationReason().find("maintenance") != std::string::npos);
+    }
+
+    SUBCASE("maintenance event during polling")
+    {
+        Imds.Mock.Gcp.MaintenanceEvent = "NONE";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        CHECK(Cloud->IsTerminationPending() == false);
+
+        Imds.Mock.Gcp.MaintenanceEvent = "TERMINATE_ON_HOST_MAINTENANCE";
+        Cloud->PollTermination();
+
+        CHECK(Cloud->IsTerminationPending() == true);
+        CHECK(Cloud->GetTerminationReason().find("maintenance") != std::string::npos);
+    }
+
+    SUBCASE("no termination when NONE")
+    {
+        Imds.Mock.Gcp.MaintenanceEvent = "NONE";
+        Imds.Start();
+
+        auto Cloud = Imds.CreateCloud();
+        Cloud->PollTermination();
+
+        CHECK(Cloud->IsTerminationPending() == false);
+    }
+}
+
+// ---------------------------------------------------------------------------
+// No provider
+// ---------------------------------------------------------------------------
+
+TEST_CASE("cloudmetadata.no_provider")
+{
+    // With no provider active, detection fails everywhere and all instance
+    // info fields stay at their defaults.
+    TestImdsServer Imds;
+    Imds.Mock.ActiveProvider = CloudProvider::None;
+    Imds.Start();
+
+    auto Cloud = Imds.CreateCloud();
+
+    CHECK(Cloud->GetProvider() == CloudProvider::None);
+
+    CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+    CHECK(Info.InstanceId.empty());
+    CHECK(Info.AvailabilityZone.empty());
+    CHECK(Info.IsSpot == false);
+    CHECK(Info.IsAutoscaling == false);
+    CHECK(Cloud->IsTerminationPending() == false);
+}
+
+// ---------------------------------------------------------------------------
+// Sentinel file management
+// ---------------------------------------------------------------------------
+
+TEST_CASE("cloudmetadata.sentinel_files")
+{
+    // Sentinel files (".isNotAWS" etc.) cache negative detection results in
+    // the data directory so later startups can skip probing that provider.
+    TestImdsServer Imds;
+    Imds.Mock.ActiveProvider = CloudProvider::None;
+    Imds.Start();
+
+    auto DataDir = Imds.DataDir();
+
+    SUBCASE("sentinels are written on failed detection")
+    {
+        auto Cloud = Imds.CreateCloud();
+
+        CHECK(Cloud->GetProvider() == CloudProvider::None);
+        CHECK(zen::IsFile(DataDir / ".isNotAWS"));
+        CHECK(zen::IsFile(DataDir / ".isNotAzure"));
+        CHECK(zen::IsFile(DataDir / ".isNotGCP"));
+    }
+
+    SUBCASE("ClearSentinelFiles removes sentinels")
+    {
+        auto Cloud = Imds.CreateCloud();
+
+        CHECK(zen::IsFile(DataDir / ".isNotAWS"));
+        CHECK(zen::IsFile(DataDir / ".isNotAzure"));
+        CHECK(zen::IsFile(DataDir / ".isNotGCP"));
+
+        Cloud->ClearSentinelFiles();
+
+        CHECK_FALSE(zen::IsFile(DataDir / ".isNotAWS"));
+        CHECK_FALSE(zen::IsFile(DataDir / ".isNotAzure"));
+        CHECK_FALSE(zen::IsFile(DataDir / ".isNotGCP"));
+    }
+
+    SUBCASE("only failed providers get sentinels")
+    {
+        // Switch to AWS — Azure and GCP never probed, so no sentinels for them
+        Imds.Mock.ActiveProvider = CloudProvider::AWS;
+
+        auto Cloud = Imds.CreateCloud();
+
+        CHECK(Cloud->GetProvider() == CloudProvider::AWS);
+        CHECK_FALSE(zen::IsFile(DataDir / ".isNotAWS"));
+        CHECK_FALSE(zen::IsFile(DataDir / ".isNotAzure"));
+        CHECK_FALSE(zen::IsFile(DataDir / ".isNotGCP"));
+    }
+}
+
+TEST_SUITE_END();
+
+// Empty symbol referenced from elsewhere so the linker keeps this test
+// translation unit (and its statically-registered tests) — presumably the
+// project's standard force-link pattern; confirm against the test registry.
+void
+cloudmetadata_forcelink()
+{
+}
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_TESTS
diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp
new file mode 100644
index 000000000..838d741b6
--- /dev/null
+++ b/src/zencompute/computeservice.cpp
@@ -0,0 +1,2236 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencompute/computeservice.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include "runners/functionrunner.h"
+# include "recording/actionrecorder.h"
+# include "runners/localrunner.h"
+# include "runners/remotehttprunner.h"
+# if ZEN_PLATFORM_LINUX
+# include "runners/linuxrunner.h"
+# elif ZEN_PLATFORM_WINDOWS
+# include "runners/windowsrunner.h"
+# elif ZEN_PLATFORM_MAC
+# include "runners/macrunner.h"
+# endif
+
+# include <zencompute/recordingreader.h>
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/compress.h>
+# include <zencore/except.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/iobuffer.h>
+# include <zencore/iohash.h>
+# include <zencore/logging.h>
+# include <zencore/scopeguard.h>
+# include <zencore/trace.h>
+# include <zencore/workthreadpool.h>
+# include <zenutil/workerpools.h>
+# include <zentelemetry/stats.h>
+# include <zenhttp/httpclient.h>
+
+# include <set>
+# include <deque>
+# include <map>
+# include <thread>
+# include <unordered_map>
+# include <unordered_set>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <EASTL/hash_set.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+using namespace std::literals;
+
+namespace zen {
+
+// Human-readable name for a session state (PascalCase, used in logs).
+// Falls back to "Unknown" for out-of-range values.
+const char*
+ToString(compute::ComputeServiceSession::SessionState State)
+{
+    using enum compute::ComputeServiceSession::SessionState;
+    switch (State)
+    {
+        case Created:
+            return "Created";
+        case Ready:
+            return "Ready";
+        case Draining:
+            return "Draining";
+        case Paused:
+            return "Paused";
+        case Abandoned:
+            return "Abandoned";
+        case Sunset:
+            return "Sunset";
+    }
+    return "Unknown";
+}
+
+// Human-readable name for a queue state. Note: intentionally lowercase,
+// unlike the PascalCase SessionState names above — presumably these strings
+// are surfaced through an API/UI; confirm before changing the casing.
+const char*
+ToString(compute::ComputeServiceSession::QueueState State)
+{
+    using enum compute::ComputeServiceSession::QueueState;
+    switch (State)
+    {
+        case Active:
+            return "active";
+        case Draining:
+            return "draining";
+        case Cancelled:
+            return "cancelled";
+    }
+    return "unknown";
+}
+
+} // namespace zen
+
+namespace zen::compute {
+
+using SessionState = ComputeServiceSession::SessionState;
+
+static_assert(ZEN_ARRAY_COUNT(ComputeServiceSession::ActionHistoryEntry::Timestamps) == static_cast<size_t>(RunnerAction::State::_Count));
+
+//////////////////////////////////////////////////////////////////////////
+
+// Private implementation (pimpl) of ComputeServiceSession. Owns the scheduler
+// thread, the local/remote runner groups, the worker registry, per-queue
+// bookkeeping, and the pending/running/completed action maps.
+struct ComputeServiceSession::Impl
+{
+    ComputeServiceSession* m_ComputeServiceSession;
+    ChunkResolver& m_ChunkResolver;
+    LoggerRef m_Log{logging::Get("compute")};
+
+    Impl(ComputeServiceSession* InComputeServiceSession, ChunkResolver& InChunkResolver)
+        : m_ComputeServiceSession(InComputeServiceSession)
+        , m_ChunkResolver(InChunkResolver)
+        , m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst))
+        , m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst))
+    {
+        // Create a non-expiring, non-deletable implicit queue for legacy endpoints
+        auto Result = CreateQueue("implicit"sv, {}, {});
+        m_ImplicitQueueId = Result.QueueId;
+        // NOTE(review): operator[] on m_Queues under a *shared* lock — safe
+        // only because the key was just inserted above (no insertion happens),
+        // but find() under the shared lock would express that more safely.
+        m_QueueLock.WithSharedLock([&] { m_Queues[m_ImplicitQueueId]->Implicit = true; });
+
+        // Start the scheduler last so all members are fully initialized first.
+        m_SchedulingThread = std::thread{&Impl::SchedulerThreadFunction, this};
+    }
+
+    void WaitUntilReady();
+    void Shutdown();
+    bool IsHealthy();
+
+    bool RequestStateTransition(SessionState NewState);
+    void AbandonAllActions();
+
+    LoggerRef Log() { return m_Log; }
+
+    // Orchestration
+
+    void SetOrchestratorEndpoint(std::string_view Endpoint);
+    void SetOrchestratorBasePath(std::filesystem::path BasePath);
+
+    std::string m_OrchestratorEndpoint;
+    std::filesystem::path m_OrchestratorBasePath;
+    Stopwatch m_OrchestratorQueryTimer;
+    std::unordered_set<std::string> m_KnownWorkerUris;
+
+    void UpdateCoordinatorState();
+
+    // Worker registration and discovery
+
+    struct FunctionDefinition
+    {
+        std::string FunctionName;
+        Guid FunctionVersion;
+        Guid BuildSystemVersion;
+        IoHash WorkerId;
+    };
+
+    void RegisterWorker(CbPackage Worker);
+    WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId);
+
+    // Action scheduling and tracking
+
+    std::atomic<SessionState> m_SessionState{SessionState::Created};
+    std::atomic<int32_t> m_ActionsCounter = 0; // sequence number
+    metrics::Meter m_ArrivalRate;
+
+    // Actions waiting to be scheduled, keyed by LSN (sequence number)
+    RwLock m_PendingLock;
+    std::map<int, Ref<RunnerAction>> m_PendingActions;
+
+    // Actions currently executing on some runner
+    RwLock m_RunningLock;
+    std::unordered_map<int, Ref<RunnerAction>> m_RunningMap;
+
+    // Finished actions whose results have not yet been retired
+    RwLock m_ResultsLock;
+    std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap;
+    metrics::Meter m_ResultRate;
+    std::atomic<uint64_t> m_RetiredCount{0};
+
+    EnqueueResult EnqueueAction(int QueueId, CbObject ActionObject, int Priority);
+    EnqueueResult EnqueueResolvedAction(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority);
+
+    void GetCompleted(CbWriter& Cbo);
+
+    HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage);
+    HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage);
+    void RetireActionResult(int ActionLsn);
+
+    std::thread m_SchedulingThread;
+    std::atomic<bool> m_SchedulingThreadEnabled{true};
+    Event m_SchedulingThreadEvent;
+
+    void SchedulerThreadFunction();
+    void SchedulePendingActions();
+
+    // Workers
+
+    RwLock m_WorkerLock;
+    std::unordered_map<IoHash, CbPackage> m_WorkerMap;
+    std::vector<FunctionDefinition> m_FunctionList;
+    std::vector<IoHash> GetKnownWorkerIds();
+    void SyncWorkersToRunner(FunctionRunner& Runner);
+
+    // Runners
+
+    DeferredDirectoryDeleter m_DeferredDeleter;
+    WorkerThreadPool& m_LocalSubmitPool;
+    WorkerThreadPool& m_RemoteSubmitPool;
+    RunnerGroup<LocalProcessRunner> m_LocalRunnerGroup;
+    RunnerGroup<RemoteHttpRunner> m_RemoteRunnerGroup;
+
+    void ShutdownRunners();
+
+    // Recording
+
+    void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath);
+    void StopRecording();
+
+    std::unique_ptr<ActionRecorder> m_Recorder;
+
+    // History tracking
+
+    RwLock m_ActionHistoryLock;
+    std::deque<ComputeServiceSession::ActionHistoryEntry> m_ActionHistory;
+    size_t m_HistoryLimit = 1000;
+
+    // Queue tracking
+
+    using QueueState = ComputeServiceSession::QueueState;
+
+    struct QueueEntry : RefCounted
+    {
+        int QueueId;
+        bool Implicit{false};
+        std::atomic<QueueState> State{QueueState::Active};
+        std::atomic<int> ActiveCount{0};    // pending + running
+        std::atomic<int> CompletedCount{0}; // successfully completed
+        std::atomic<int> FailedCount{0};    // failed
+        std::atomic<int> AbandonedCount{0}; // abandoned
+        std::atomic<int> CancelledCount{0}; // cancelled
+        std::atomic<uint64_t> IdleSince{0}; // hifreq tick when queue became idle; 0 = has active work
+
+        RwLock m_Lock;
+        std::unordered_set<int> ActiveLsns;   // for cancellation lookup
+        std::unordered_set<int> FinishedLsns; // completed/failed/cancelled LSNs
+
+        std::string Tag;
+        CbObject Metadata;
+        CbObject Config;
+    };
+
+    int m_ImplicitQueueId{0};
+    std::atomic<int> m_QueueCounter{0};
+    RwLock m_QueueLock;
+    std::unordered_map<int, Ref<QueueEntry>> m_Queues;
+
+    // Looks up a queue by id; returns a null Ref when the id is unknown.
+    Ref<QueueEntry> FindQueue(int QueueId)
+    {
+        Ref<QueueEntry> Queue;
+        m_QueueLock.WithSharedLock([&] {
+            if (auto It = m_Queues.find(QueueId); It != m_Queues.end())
+            {
+                Queue = It->second;
+            }
+        });
+        return Queue;
+    }
+
+    ComputeServiceSession::CreateQueueResult CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config);
+    std::vector<int> GetQueueIds();
+    ComputeServiceSession::QueueStatus GetQueueStatus(int QueueId);
+    CbObject GetQueueMetadata(int QueueId);
+    CbObject GetQueueConfig(int QueueId);
+    void CancelQueue(int QueueId);
+    void DeleteQueue(int QueueId);
+    void DrainQueue(int QueueId);
+    ComputeServiceSession::EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority);
+    ComputeServiceSession::EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority);
+    void GetQueueCompleted(int QueueId, CbWriter& Cbo);
+    void NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState);
+    void ExpireCompletedQueues();
+
+    Stopwatch m_QueueExpiryTimer;
+
+    std::vector<ComputeServiceSession::RunningActionInfo> GetRunningActions();
+    std::vector<ComputeServiceSession::ActionHistoryEntry> GetActionHistory(int Limit);
+    std::vector<ComputeServiceSession::ActionHistoryEntry> GetQueueHistory(int QueueId, int Limit);
+
+    // Action submission
+
+    [[nodiscard]] size_t QueryCapacity();
+
+    [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action);
+    [[nodiscard]] std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);
+    [[nodiscard]] size_t GetSubmittedActionCount();
+
+    // Updates
+
+    RwLock m_UpdatedActionsLock;
+    std::vector<Ref<RunnerAction>> m_UpdatedActions;
+
+    void HandleActionUpdates();
+    void PostUpdate(RunnerAction* Action);
+
+    static constexpr int kDefaultMaxRetries = 3;
+    int GetMaxRetriesForQueue(int QueueId);
+
+    ComputeServiceSession::RescheduleResult RescheduleAction(int ActionLsn);
+
+    // Snapshot of pending/running/completed counts plus active (non-implicit)
+    // queue count. "Completed" includes results already retired from the map.
+    ActionCounts GetActionCounts()
+    {
+        ActionCounts Counts;
+        Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
+        Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
+        Counts.Completed = (int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load();
+        Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] {
+            size_t Count = 0;
+            for (const auto& [Id, Queue] : m_Queues)
+            {
+                if (!Queue->Implicit)
+                {
+                    ++Count;
+                }
+            }
+            return Count;
+        });
+        return Counts;
+    }
+
+    // Writes session-level telemetry fields into the supplied writer.
+    void EmitStats(CbObjectWriter& Cbo)
+    {
+        Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed));
+        m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); });
+        m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); });
+        m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); });
+        Cbo << "actions_submitted"sv << GetSubmittedActionCount();
+        EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo);
+        EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo);
+    }
+};
+
+// Healthy = any state before Abandoned (relies on enum declaration order:
+// Created < Ready < Draining < Paused < Abandoned < Sunset).
+bool
+ComputeServiceSession::Impl::IsHealthy()
+{
+    return m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned;
+}
+
+// Attempts a CAS-based state-machine transition to NewState. Returns true if
+// the session is (or ends up) in NewState, false if the transition is invalid.
+// Transitioning to Abandoned also abandons all in-flight actions.
+bool
+ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState)
+{
+    SessionState Current = m_SessionState.load(std::memory_order_relaxed);
+
+    for (;;)
+    {
+        // Already there (possibly via a concurrent caller) — treat as success.
+        if (Current == NewState)
+        {
+            return true;
+        }
+
+        // Validate the transition
+        bool Valid = false;
+
+        switch (Current)
+        {
+            case SessionState::Created:
+                Valid = (NewState == SessionState::Ready);
+                break;
+            case SessionState::Ready:
+                Valid = (NewState == SessionState::Draining);
+                break;
+            case SessionState::Draining:
+                Valid = (NewState == SessionState::Ready || NewState == SessionState::Paused);
+                break;
+            case SessionState::Paused:
+                Valid = (NewState == SessionState::Ready || NewState == SessionState::Sunset);
+                break;
+            case SessionState::Abandoned:
+                Valid = (NewState == SessionState::Sunset);
+                break;
+            case SessionState::Sunset:
+                // Terminal state — no way out.
+                Valid = false;
+                break;
+        }
+
+        // Allow jumping directly to Abandoned or Sunset from any non-terminal state
+        if (NewState == SessionState::Abandoned && Current < SessionState::Abandoned)
+        {
+            Valid = true;
+        }
+        if (NewState == SessionState::Sunset && Current != SessionState::Sunset)
+        {
+            Valid = true;
+        }
+
+        if (!Valid)
+        {
+            ZEN_WARN("invalid session state transition {} -> {}", ToString(Current), ToString(NewState));
+            return false;
+        }
+
+        if (m_SessionState.compare_exchange_strong(Current, NewState, std::memory_order_acq_rel))
+        {
+            // On CAS success, Current still holds the pre-transition state,
+            // so this logs old -> new as intended.
+            ZEN_INFO("session state: {} -> {}", ToString(Current), ToString(NewState));
+
+            if (NewState == SessionState::Abandoned)
+            {
+                AbandonAllActions();
+            }
+
+            return true;
+        }
+
+        // CAS failed, Current was updated — retry with the new value
+    }
+}
+
+// Marks every pending and running action as Abandoned. Snapshots are taken
+// under shared locks and the state changes are applied outside the locks to
+// avoid holding them across SetActionState/CancelAction.
+void
+ComputeServiceSession::Impl::AbandonAllActions()
+{
+    // Collect all pending actions and mark them as Abandoned
+    std::vector<Ref<RunnerAction>> PendingToAbandon;
+
+    m_PendingLock.WithSharedLock([&] {
+        PendingToAbandon.reserve(m_PendingActions.size());
+        for (auto& [Lsn, Action] : m_PendingActions)
+        {
+            PendingToAbandon.push_back(Action);
+        }
+    });
+
+    for (auto& Action : PendingToAbandon)
+    {
+        Action->SetActionState(RunnerAction::State::Abandoned);
+    }
+
+    // Collect all running actions and mark them as Abandoned, then
+    // best-effort cancel via the local runner group
+    // (NOTE: only the local group gets CancelAction; remote actions are just
+    // marked Abandoned here)
+    std::vector<Ref<RunnerAction>> RunningToAbandon;
+
+    m_RunningLock.WithSharedLock([&] {
+        RunningToAbandon.reserve(m_RunningMap.size());
+        for (auto& [Lsn, Action] : m_RunningMap)
+        {
+            RunningToAbandon.push_back(Action);
+        }
+    });
+
+    for (auto& Action : RunningToAbandon)
+    {
+        Action->SetActionState(RunnerAction::State::Abandoned);
+        m_LocalRunnerGroup.CancelAction(Action->ActionLsn);
+    }
+
+    ZEN_INFO("abandoned all actions: {} pending, {} running", PendingToAbandon.size(), RunningToAbandon.size());
+}
+
+// Sets the orchestrator base URI; a non-empty endpoint enables polling in
+// UpdateCoordinatorState().
+void
+ComputeServiceSession::Impl::SetOrchestratorEndpoint(std::string_view Endpoint)
+{
+    m_OrchestratorEndpoint = Endpoint;
+}
+
+// Base path handed to RemoteHttpRunner instances created during discovery.
+void
+ComputeServiceSession::Impl::SetOrchestratorBasePath(std::filesystem::path BasePath)
+{
+    m_OrchestratorBasePath = std::move(BasePath);
+}
+
+// Polls the orchestrator's /orch/agents endpoint, creating RemoteHttpRunner
+// instances for newly discovered reachable workers and removing runners whose
+// workers have gone stale or unreachable. No-op when no endpoint is set.
+// All HTTP/parse errors are logged and swallowed — discovery retries later.
+void
+ComputeServiceSession::Impl::UpdateCoordinatorState()
+{
+    ZEN_TRACE_CPU("ComputeServiceSession::UpdateCoordinatorState");
+    if (m_OrchestratorEndpoint.empty())
+    {
+        return;
+    }
+
+    // Poll faster when we have no discovered workers yet so remote runners come online quickly
+    const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 500 : 5000;
+    if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs)
+    {
+        return;
+    }
+
+    m_OrchestratorQueryTimer.Reset();
+
+    try
+    {
+        HttpClient Client(m_OrchestratorEndpoint);
+
+        HttpClient::Response Response = Client.Get("/orch/agents");
+
+        if (!Response.IsSuccess())
+        {
+            ZEN_WARN("orchestrator query failed with status {}", static_cast<int>(Response.StatusCode));
+            return;
+        }
+
+        CbObject WorkerList = Response.AsObject();
+
+        std::unordered_set<std::string> ValidWorkerUris;
+
+        for (auto& Item : WorkerList["workers"sv])
+        {
+            CbObjectView Worker = Item.AsObjectView();
+
+            // "dt" = milliseconds since the orchestrator last saw this worker
+            uint64_t Dt = Worker["dt"sv].AsUInt64();
+            bool Reachable = Worker["reachable"sv].AsBool();
+            std::string_view Uri = Worker["uri"sv].AsString();
+
+            // Skip stale workers (not seen in over 30 seconds)
+            if (Dt > 30000)
+            {
+                continue;
+            }
+
+            // Skip workers that are not confirmed reachable
+            if (!Reachable)
+            {
+                continue;
+            }
+
+            std::string UriStr{Uri};
+            ValidWorkerUris.insert(UriStr);
+
+            // Skip workers we already know about
+            if (m_KnownWorkerUris.contains(UriStr))
+            {
+                continue;
+            }
+
+            ZEN_INFO("discovered new worker at {}", UriStr);
+
+            m_KnownWorkerUris.insert(UriStr);
+
+            // Replay all currently-registered worker packages to the new
+            // runner before it joins the group.
+            // NOTE(review): raw new — ownership presumably transfers to the
+            // RunnerGroup via AddRunner; confirm against RunnerGroup's API.
+            auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool);
+            SyncWorkersToRunner(*NewRunner);
+            m_RemoteRunnerGroup.AddRunner(NewRunner);
+        }
+
+        // Remove workers that are no longer valid (stale or unreachable)
+        for (auto It = m_KnownWorkerUris.begin(); It != m_KnownWorkerUris.end();)
+        {
+            if (!ValidWorkerUris.contains(*It))
+            {
+                const std::string& ExpiredUri = *It;
+                ZEN_INFO("removing expired worker at {}", ExpiredUri);
+
+                m_RemoteRunnerGroup.RemoveRunnerIf([&](const RemoteHttpRunner& Runner) { return Runner.GetHostName() == ExpiredUri; });
+
+                It = m_KnownWorkerUris.erase(It);
+            }
+            else
+            {
+                ++It;
+            }
+        }
+    }
+    catch (const HttpClientError& Ex)
+    {
+        ZEN_WARN("orchestrator query error: {}", Ex.what());
+    }
+    catch (const std::exception& Ex)
+    {
+        ZEN_WARN("orchestrator query unexpected error: {}", Ex.what());
+    }
+}
+
+// Blocks until remote runner capacity appears (up to ~120s) when remote
+// execution is configured, then moves the session to Ready. Note that on
+// timeout we still transition to Ready — remote capacity may arrive later.
+void
+ComputeServiceSession::Impl::WaitUntilReady()
+{
+    if (m_RemoteRunnerGroup.GetRunnerCount() || !m_OrchestratorEndpoint.empty())
+    {
+        ZEN_INFO("waiting for remote runners...");
+
+        constexpr int MaxWaitSeconds = 120;
+
+        for (int Elapsed = 0; Elapsed < MaxWaitSeconds; Elapsed++)
+        {
+            // Bail out early if Shutdown() has been requested meanwhile.
+            if (!m_SchedulingThreadEnabled.load(std::memory_order_relaxed))
+            {
+                ZEN_WARN("shutdown requested while waiting for remote runners");
+                return;
+            }
+
+            const size_t Capacity = m_RemoteRunnerGroup.QueryCapacity();
+
+            if (Capacity > 0)
+            {
+                ZEN_INFO("found {} remote runners (capacity: {})", m_RemoteRunnerGroup.GetRunnerCount(), Capacity);
+                break;
+            }
+
+            zen::Sleep(1000);
+        }
+    }
+    else
+    {
+        // Local-only configuration must have at least one local runner.
+        ZEN_ASSERT(m_LocalRunnerGroup.GetRunnerCount(), "no runners available");
+    }
+
+    RequestStateTransition(SessionState::Ready);
+}
+
+// Orderly teardown: move to the terminal Sunset state, stop and join the
+// scheduler thread, shut down all runners, then flush deferred deletions.
+void
+ComputeServiceSession::Impl::Shutdown()
+{
+    RequestStateTransition(SessionState::Sunset);
+
+    // Signal the scheduler loop to exit and wake it if it is waiting.
+    m_SchedulingThreadEnabled = false;
+    m_SchedulingThreadEvent.Set();
+    if (m_SchedulingThread.joinable())
+    {
+        m_SchedulingThread.join();
+    }
+
+    ShutdownRunners();
+
+    m_DeferredDeleter.Shutdown();
+}
+
+// Shuts down both runner groups (local first, then remote).
+void
+ComputeServiceSession::Impl::ShutdownRunners()
+{
+    m_LocalRunnerGroup.Shutdown();
+    m_RemoteRunnerGroup.Shutdown();
+}
+
+// Creates the action recorder; subsequently registered workers (and whatever
+// else the recorder captures) are written to RecordingPath.
+void
+ComputeServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath)
+{
+    ZEN_INFO("starting recording to '{}'", RecordingPath);
+
+    m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath);
+
+    ZEN_INFO("started recording to '{}'", RecordingPath);
+}
+
+// Destroys the recorder; the ActionRecorder destructor finalizes the recording.
+void
+ComputeServiceSession::Impl::StopRecording()
+{
+    ZEN_INFO("stopping recording");
+
+    m_Recorder = nullptr;
+
+    ZEN_INFO("stopped recording");
+}
+
+// Snapshot of currently-running actions (LSN, queue, id, CPU usage) taken
+// under the shared running-map lock.
+std::vector<ComputeServiceSession::RunningActionInfo>
+ComputeServiceSession::Impl::GetRunningActions()
+{
+    std::vector<ComputeServiceSession::RunningActionInfo> Result;
+    m_RunningLock.WithSharedLock([&] {
+        Result.reserve(m_RunningMap.size());
+        for (const auto& [Lsn, Action] : m_RunningMap)
+        {
+            Result.push_back({.Lsn = Lsn,
+                              .QueueId = Action->QueueId,
+                              .ActionId = Action->ActionId,
+                              .CpuUsagePercent = Action->CpuUsagePercent.load(std::memory_order_relaxed),
+                              .CpuSeconds = Action->CpuSeconds.load(std::memory_order_relaxed)});
+        }
+    });
+    return Result;
+}
+
+// Copies the global action history (oldest-first). A positive Limit returns
+// only the most recent Limit entries; Limit <= 0 returns everything.
+std::vector<ComputeServiceSession::ActionHistoryEntry>
+ComputeServiceSession::Impl::GetActionHistory(int Limit)
+{
+    RwLock::SharedLockScope _(m_ActionHistoryLock);
+
+    if (Limit > 0 && static_cast<size_t>(Limit) < m_ActionHistory.size())
+    {
+        // deque iterators are random-access, so end() - Limit is valid.
+        return std::vector<ActionHistoryEntry>(m_ActionHistory.end() - Limit, m_ActionHistory.end());
+    }
+
+    return std::vector<ActionHistoryEntry>(m_ActionHistory.begin(), m_ActionHistory.end());
+}
+
+// History entries for one queue: intersects the global history with the
+// queue's finished-LSN set. Returns empty for an unknown queue. A positive
+// Limit trims to the most recent Limit matching entries.
+std::vector<ComputeServiceSession::ActionHistoryEntry>
+ComputeServiceSession::Impl::GetQueueHistory(int QueueId, int Limit)
+{
+    // Resolve the queue and snapshot its finished LSN set
+    Ref<QueueEntry> Queue = FindQueue(QueueId);
+
+    if (!Queue)
+    {
+        return {};
+    }
+
+    std::unordered_set<int> FinishedLsns;
+
+    // Copy under the queue lock so we don't hold it while walking the history.
+    Queue->m_Lock.WithSharedLock([&] { FinishedLsns = Queue->FinishedLsns; });
+
+    // Filter the global history to entries belonging to this queue.
+    // m_ActionHistory is ordered oldest-first, so the filtered result keeps the same ordering.
+    std::vector<ActionHistoryEntry> Result;
+
+    m_ActionHistoryLock.WithSharedLock([&] {
+        for (const auto& Entry : m_ActionHistory)
+        {
+            if (FinishedLsns.contains(Entry.Lsn))
+            {
+                Result.push_back(Entry);
+            }
+        }
+    });
+
+    if (Limit > 0 && static_cast<size_t>(Limit) < Result.size())
+    {
+        // Keep the newest Limit entries (drop from the front).
+        Result.erase(Result.begin(), Result.end() - Limit);
+    }
+
+    return Result;
+}
+
+// Adds a worker package to the registry. First-time registrations are fanned
+// out to both runner groups, the recorder (if active), and the function list.
+void
+ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker)
+{
+    ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker");
+    RwLock::ExclusiveLockScope _(m_WorkerLock);
+
+    // NOTE(review): binding a const ref through GetObject() — if GetObject()
+    // returns by value this reference dangles; confirm it returns a reference.
+    const IoHash& WorkerId = Worker.GetObject().GetHash();
+
+    // insert_or_assign().second is true only on insertion, so the body below
+    // runs exactly once per distinct WorkerId.
+    if (m_WorkerMap.insert_or_assign(WorkerId, Worker).second)
+    {
+        // Note that since the convention currently is that WorkerId is equal to the hash
+        // of the worker descriptor there is no chance that we get a second write with a
+        // different descriptor. Thus we only need to call this the first time, when the
+        // worker is added
+
+        m_LocalRunnerGroup.RegisterWorker(Worker);
+        m_RemoteRunnerGroup.RegisterWorker(Worker);
+
+        if (m_Recorder)
+        {
+            m_Recorder->RegisterWorker(Worker);
+        }
+
+        CbObject WorkerObj = Worker.GetObject();
+
+        // Populate worker database
+
+        const Guid WorkerBuildSystemVersion = WorkerObj["buildsystem_version"sv].AsUuid();
+
+        for (auto& Item : WorkerObj["functions"sv])
+        {
+            CbObjectView Function = Item.AsObjectView();
+
+            std::string_view FunctionName = Function["name"sv].AsString();
+            const Guid FunctionVersion = Function["version"sv].AsUuid();
+
+            m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName},
+                                                           .FunctionVersion = FunctionVersion,
+                                                           .BuildSystemVersion = WorkerBuildSystemVersion,
+                                                           .WorkerId = WorkerId});
+        }
+    }
+}
+
+// Replays every currently known worker into the given runner so it sees the
+// same worker set as the session.
+void
+ComputeServiceSession::Impl::SyncWorkersToRunner(FunctionRunner& Runner)
+{
+    ZEN_TRACE_CPU("SyncWorkersToRunner");
+
+    std::vector<CbPackage> Workers;
+
+    // Snapshot the worker set under the lock; register outside of it so the
+    // worker lock is not held across the runner callbacks.
+    {
+        RwLock::SharedLockScope _(m_WorkerLock);
+        Workers.reserve(m_WorkerMap.size());
+        for (const auto& [Id, Pkg] : m_WorkerMap)
+        {
+            Workers.push_back(Pkg);
+        }
+    }
+
+    for (const CbPackage& Worker : Workers)
+    {
+        Runner.RegisterWorker(Worker);
+    }
+}
+
+// Looks up a registered worker by id. Returns a default-constructed
+// WorkerDesc when the worker is unknown.
+WorkerDesc
+ComputeServiceSession::Impl::GetWorkerDescriptor(const IoHash& WorkerId)
+{
+    RwLock::SharedLockScope _(m_WorkerLock);
+
+    if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end())
+    {
+        const CbPackage& Desc = It->second;
+        return {Desc, WorkerId};
+    }
+
+    return {};
+}
+
+// Returns the ids of all workers currently registered with the session.
+std::vector<IoHash>
+ComputeServiceSession::Impl::GetKnownWorkerIds()
+{
+    std::vector<IoHash> WorkerIds;
+
+    m_WorkerLock.WithSharedLock([&] {
+        WorkerIds.reserve(m_WorkerMap.size());
+        for (const auto& [WorkerId, _] : m_WorkerMap)
+        {
+            WorkerIds.push_back(WorkerId);
+        }
+    });
+
+    return WorkerIds;
+}
+
+// Resolves the action's (Function, FunctionVersion, BuildSystemVersion)
+// triple to a registered worker and enqueues it. On failure returns
+// {Lsn = 0, <error object>}; success is delegated to EnqueueResolvedAction.
+ComputeServiceSession::EnqueueResult
+ComputeServiceSession::Impl::EnqueueAction(int QueueId, CbObject ActionObject, int Priority)
+{
+    ZEN_TRACE_CPU("ComputeServiceSession::EnqueueAction");
+
+    // Resolve function to worker
+
+    IoHash WorkerId{IoHash::Zero};
+    CbPackage WorkerPackage;
+
+    std::string_view FunctionName = ActionObject["Function"sv].AsString();
+    const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid();
+    const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid();
+
+    // Linear scan of the function index; first exact match wins.
+    m_WorkerLock.WithSharedLock([&] {
+        for (const FunctionDefinition& FuncDef : m_FunctionList)
+        {
+            if (FuncDef.FunctionName == FunctionName && FuncDef.FunctionVersion == FunctionVersion &&
+                FuncDef.BuildSystemVersion == BuildSystemVersion)
+            {
+                WorkerId = FuncDef.WorkerId;
+
+                break;
+            }
+        }
+
+        if (WorkerId != IoHash::Zero)
+        {
+            if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end())
+            {
+                WorkerPackage = It->second;
+            }
+        }
+    });
+
+    // No function definition matched — report which spec failed to resolve.
+    if (WorkerId == IoHash::Zero)
+    {
+        CbObjectWriter Writer;
+
+        Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion;
+        Writer << "error"
+               << "no worker matches the action specification";
+
+        return {0, Writer.Save()};
+    }
+
+    if (WorkerPackage)
+    {
+        return EnqueueResolvedAction(QueueId, WorkerDesc{WorkerPackage, WorkerId}, ActionObject, Priority);
+    }
+
+    // The index pointed at a worker id that is no longer in the worker map.
+    CbObjectWriter Writer;
+
+    Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion;
+    Writer << "error"
+           << "no worker found despite match";
+
+    return {0, Writer.Save()};
+}
+
+// Enqueues an action whose worker has already been resolved. Assigns a fresh
+// LSN, creates the RunnerAction in Pending state (batch scheduling picks it
+// up later), and returns {Lsn, <lsn/worker/action object>}. Returns
+// {Lsn = 0, <error object>} when the session is not in the Ready state.
+ComputeServiceSession::EnqueueResult
+ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority)
+{
+    ZEN_TRACE_CPU("ComputeServiceSession::EnqueueResolvedAction");
+
+    if (m_SessionState.load(std::memory_order_relaxed) != SessionState::Ready)
+    {
+        CbObjectWriter Writer;
+        Writer << "error"sv << fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load()));
+        return {0, Writer.Save()};
+    }
+
+    // LSNs are allocated from a monotonically increasing counter; 0 is never
+    // a valid LSN (it doubles as the failure value in EnqueueResult).
+    const int ActionLsn = ++m_ActionsCounter;
+
+    m_ArrivalRate.Mark();
+
+    Ref<RunnerAction> Pending{new RunnerAction(m_ComputeServiceSession)};
+
+    Pending->ActionLsn = ActionLsn;
+    Pending->QueueId = QueueId;
+    Pending->Worker = Worker;
+    Pending->ActionId = ActionObj.GetHash();
+    Pending->ActionObj = ActionObj;
+    Pending->Priority = RequestPriority;
+
+    // For now simply put action into pending state, so we can do batch scheduling
+
+    ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn);
+
+    Pending->SetActionState(RunnerAction::State::Pending);
+
+    if (m_Recorder)
+    {
+        m_Recorder->RecordAction(Pending);
+    }
+
+    CbObjectWriter Writer;
+    Writer << "lsn" << Pending->ActionLsn;
+    Writer << "worker" << Pending->Worker.WorkerId;
+    Writer << "action" << Pending->ActionId;
+
+    return {Pending->ActionLsn, Writer.Save()};
+}
+
+// Submits a single action, trying the local runner group first and falling
+// back to the remote group only if the local group rejects it.
+//
+// NOTE(review): the comment below mentions round-robin, but the code here is
+// strictly local-first then remote — confirm which description is intended.
+SubmitResult
+ComputeServiceSession::Impl::SubmitAction(Ref<RunnerAction> Action)
+{
+    // Loosely round-robin scheduling of actions across runners.
+    //
+    // It's not entirely clear what this means given that submits
+    // can come in across multiple threads, but it's probably better
+    // than always starting with the first runner.
+    //
+    // Longer term we should track the state of the individual
+    // runners and make decisions accordingly.
+
+    SubmitResult Result = m_LocalRunnerGroup.SubmitAction(Action);
+    if (Result.IsAccepted)
+    {
+        return Result;
+    }
+
+    return m_RemoteRunnerGroup.SubmitAction(Action);
+}
+
+// Total number of actions currently submitted across both runner groups.
+size_t
+ComputeServiceSession::Impl::GetSubmittedActionCount()
+{
+    return m_LocalRunnerGroup.GetSubmittedActionCount() + m_RemoteRunnerGroup.GetSubmittedActionCount();
+}
+
+// Fetches the result of a finished action by LSN. This is a consuming read:
+// on OK the result is moved out and erased from the results map, so a second
+// call with the same LSN reports NotFound. Returns Accepted while the action
+// is still pending or running.
+HttpResponseCode
+ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage)
+{
+    // This lock is held for the duration of the function since we need to
+    // be sure that the action doesn't change state while we are checking the
+    // different data structures
+
+    RwLock::ExclusiveLockScope _(m_ResultsLock);
+
+    if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end())
+    {
+        OutResultPackage = std::move(It->second->GetResult());
+
+        m_ResultsMap.erase(It);
+
+        return HttpResponseCode::OK;
+    }
+
+    {
+        RwLock::SharedLockScope __(m_PendingLock);
+
+        if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end())
+        {
+            return HttpResponseCode::Accepted;
+        }
+    }
+
+    // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must
+    // always be taken after m_ResultsLock if both are needed
+
+    {
+        RwLock::SharedLockScope __(m_RunningLock);
+
+        if (m_RunningMap.find(ActionLsn) != m_RunningMap.end())
+        {
+            return HttpResponseCode::Accepted;
+        }
+    }
+
+    return HttpResponseCode::NotFound;
+}
+
+// Same contract as GetActionResult, but keyed by ActionId (content hash)
+// instead of LSN, which requires linear scans of the three maps. A hit in
+// the results map is consuming: the entry is moved out and erased.
+HttpResponseCode
+ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage)
+{
+    // This lock is held for the duration of the function since we need to
+    // be sure that the action doesn't change state while we are checking the
+    // different data structures
+
+    RwLock::ExclusiveLockScope _(m_ResultsLock);
+
+    for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It)
+    {
+        if (It->second->ActionId == ActionId)
+        {
+            OutResultPackage = std::move(It->second->GetResult());
+
+            m_ResultsMap.erase(It);
+
+            return HttpResponseCode::OK;
+        }
+    }
+
+    {
+        RwLock::SharedLockScope __(m_PendingLock);
+
+        for (const auto& [K, Pending] : m_PendingActions)
+        {
+            if (Pending->ActionId == ActionId)
+            {
+                return HttpResponseCode::Accepted;
+            }
+        }
+    }
+
+    // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must
+    // always be taken after m_ResultsLock if both are needed
+
+    {
+        RwLock::SharedLockScope __(m_RunningLock);
+
+        for (const auto& [K, v] : m_RunningMap)
+        {
+            if (v->ActionId == ActionId)
+            {
+                return HttpResponseCode::Accepted;
+            }
+        }
+    }
+
+    return HttpResponseCode::NotFound;
+}
+
+// Marks an action result as consumed so the deferred deleter can reclaim it
+// later (actual deletion happens asynchronously, not here).
+void
+ComputeServiceSession::Impl::RetireActionResult(int ActionLsn)
+{
+    m_DeferredDeleter.MarkReady(ActionLsn);
+}
+
+// Appends a "completed" array to the writer listing every retained result
+// (LSN + terminal state name).
+void
+ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo)
+{
+    Cbo.BeginArray("completed");
+
+    m_ResultsLock.WithSharedLock([&] {
+        for (auto& [Lsn, Action] : m_ResultsMap)
+        {
+            Cbo.BeginObject();
+            Cbo << "lsn"sv << Lsn;
+            Cbo << "state"sv << RunnerAction::ToString(Action->ActionState());
+            Cbo.EndObject();
+        }
+    });
+
+    Cbo.EndArray();
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Queue management
+
+// Creates a new (explicit) queue with the given tag, caller metadata and
+// configuration, and registers it under a freshly allocated id. The queue
+// starts idle (IdleSince set now) so it is eligible for expiry if never used.
+ComputeServiceSession::CreateQueueResult
+ComputeServiceSession::Impl::CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config)
+{
+    const int QueueId = ++m_QueueCounter;
+
+    Ref<QueueEntry> Queue{new QueueEntry()};
+    Queue->QueueId = QueueId;
+    Queue->Tag = Tag;
+    Queue->Metadata = std::move(Metadata);
+    Queue->Config = std::move(Config);
+    Queue->IdleSince = GetHifreqTimerValue();
+
+    m_QueueLock.WithExclusiveLock([&] { m_Queues[QueueId] = Queue; });
+
+    ZEN_DEBUG("created queue {}", QueueId);
+
+    return {.QueueId = QueueId};
+}
+
+// Returns the ids of all explicit queues; the implicit queue is excluded.
+std::vector<int>
+ComputeServiceSession::Impl::GetQueueIds()
+{
+    std::vector<int> Ids;
+
+    m_QueueLock.WithSharedLock([&] {
+        Ids.reserve(m_Queues.size());
+        for (const auto& [Id, Queue] : m_Queues)
+        {
+            if (!Queue->Implicit)
+            {
+                Ids.push_back(Id);
+            }
+        }
+    });
+
+    return Ids;
+}
+
+// Snapshots a queue's counters and state. The individual counters are read
+// with relaxed atomics, so the snapshot is not a single consistent point in
+// time. IsValid is false (default-constructed result) for unknown queues.
+ComputeServiceSession::QueueStatus
+ComputeServiceSession::Impl::GetQueueStatus(int QueueId)
+{
+    Ref<QueueEntry> Queue = FindQueue(QueueId);
+
+    if (!Queue)
+    {
+        return {};
+    }
+
+    const int Active = Queue->ActiveCount.load(std::memory_order_relaxed);
+    const int Completed = Queue->CompletedCount.load(std::memory_order_relaxed);
+    const int Failed = Queue->FailedCount.load(std::memory_order_relaxed);
+    const int AbandonedN = Queue->AbandonedCount.load(std::memory_order_relaxed);
+    const int CancelledN = Queue->CancelledCount.load(std::memory_order_relaxed);
+    const QueueState QState = Queue->State.load();
+
+    return {
+        .IsValid = true,
+        .QueueId = QueueId,
+        .ActiveCount = Active,
+        .CompletedCount = Completed,
+        .FailedCount = Failed,
+        .AbandonedCount = AbandonedN,
+        .CancelledCount = CancelledN,
+        .State = QState,
+        .IsComplete = (Active == 0),
+    };
+}
+
+// Returns the caller-supplied metadata object for a queue, or an empty
+// object when the queue does not exist.
+CbObject
+ComputeServiceSession::Impl::GetQueueMetadata(int QueueId)
+{
+    Ref<QueueEntry> Queue = FindQueue(QueueId);
+
+    if (!Queue)
+    {
+        return {};
+    }
+
+    return Queue->Metadata;
+}
+
+// Returns the configuration object for a queue (see GetMaxRetriesForQueue
+// for one consumer), or an empty object when the queue does not exist.
+CbObject
+ComputeServiceSession::Impl::GetQueueConfig(int QueueId)
+{
+    Ref<QueueEntry> Queue = FindQueue(QueueId);
+
+    if (!Queue)
+    {
+        return {};
+    }
+
+    return Queue->Config;
+}
+
+// Cancels an explicit queue: marks it Cancelled, flips every pending action
+// to the Cancelled state, requests best-effort cancellation of running
+// actions, and notifies the remote runner group. The implicit queue (id used
+// for unqueued actions) is never cancelled.
+void
+ComputeServiceSession::Impl::CancelQueue(int QueueId)
+{
+    Ref<QueueEntry> Queue = FindQueue(QueueId);
+
+    if (!Queue || Queue->Implicit)
+    {
+        return;
+    }
+
+    Queue->State.store(QueueState::Cancelled);
+
+    // Collect active LSNs snapshot for cancellation
+    std::vector<int> LsnsToCancel;
+
+    Queue->m_Lock.WithSharedLock([&] { LsnsToCancel.assign(Queue->ActiveLsns.begin(), Queue->ActiveLsns.end()); });
+
+    // Identify which LSNs are still pending (not yet dispatched to a runner)
+    std::vector<Ref<RunnerAction>> PendingActionsToCancel;
+    std::vector<int> RunningLsnsToCancel;
+
+    m_PendingLock.WithSharedLock([&] {
+        for (int Lsn : LsnsToCancel)
+        {
+            if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end())
+            {
+                PendingActionsToCancel.push_back(It->second);
+            }
+        }
+    });
+
+    m_RunningLock.WithSharedLock([&] {
+        for (int Lsn : LsnsToCancel)
+        {
+            if (m_RunningMap.find(Lsn) != m_RunningMap.end())
+            {
+                RunningLsnsToCancel.push_back(Lsn);
+            }
+        }
+    });
+
+    // Cancel pending actions by marking them as Cancelled; they will flow through
+    // HandleActionUpdates and eventually be removed from the pending map.
+    for (auto& Action : PendingActionsToCancel)
+    {
+        Action->SetActionState(RunnerAction::State::Cancelled);
+    }
+
+    // Best-effort cancellation of running actions via the local runner group.
+    // Also set the action state to Cancelled directly so a subsequent Failed
+    // transition from the runner is blocked (Cancelled > Failed in the enum).
+    for (int Lsn : RunningLsnsToCancel)
+    {
+        // Re-check membership under the lock: the action may have finished
+        // between the snapshot above and this point.
+        m_RunningLock.WithSharedLock([&] {
+            if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end())
+            {
+                It->second->SetActionState(RunnerAction::State::Cancelled);
+            }
+        });
+        m_LocalRunnerGroup.CancelAction(Lsn);
+    }
+
+    m_RemoteRunnerGroup.CancelRemoteQueue(QueueId);
+
+    ZEN_INFO("cancelled queue {}: {} pending, {} running actions cancelled",
+             QueueId,
+             PendingActionsToCancel.size(),
+             RunningLsnsToCancel.size());
+
+    // Wake up the scheduler to process the cancelled actions
+    m_SchedulingThreadEvent.Set();
+}
+
+// Deletes an explicit queue after cancelling any active work in it. The
+// implicit queue can never be deleted.
+void
+ComputeServiceSession::Impl::DeleteQueue(int QueueId)
+{
+    // Never delete the implicit queue
+    {
+        Ref<QueueEntry> Queue = FindQueue(QueueId);
+        if (Queue && Queue->Implicit)
+        {
+            return;
+        }
+    }
+
+    // Cancel any active work first
+    CancelQueue(QueueId);
+
+    m_QueueLock.WithExclusiveLock([&] {
+        if (auto It = m_Queues.find(QueueId); It != m_Queues.end())
+        {
+            m_Queues.erase(It);
+        }
+    });
+}
+
+// Transitions an explicit queue from Active to Draining (no new enqueues are
+// accepted in that state — see EnqueueActionToQueue). A queue that is already
+// Cancelled or Draining is left unchanged by the compare-exchange.
+void
+ComputeServiceSession::Impl::DrainQueue(int QueueId)
+{
+    Ref<QueueEntry> Queue = FindQueue(QueueId);
+
+    if (!Queue || Queue->Implicit)
+    {
+        return;
+    }
+
+    QueueState Expected = QueueState::Active;
+    Queue->State.compare_exchange_strong(Expected, QueueState::Draining);
+    ZEN_INFO("draining queue {}", QueueId);
+}
+
+// Enqueues an unresolved action into a specific queue. Rejects unknown,
+// cancelled or draining queues; on success tracks the new LSN as active on
+// the queue and clears the queue's idle timestamp.
+ComputeServiceSession::EnqueueResult
+ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority)
+{
+    Ref<QueueEntry> Queue = FindQueue(QueueId);
+
+    if (!Queue)
+    {
+        CbObjectWriter Writer;
+        Writer << "error"sv
+               << "queue not found"sv;
+        return {0, Writer.Save()};
+    }
+
+    QueueState QState = Queue->State.load();
+    if (QState == QueueState::Cancelled)
+    {
+        CbObjectWriter Writer;
+        Writer << "error"sv
+               << "queue is cancelled"sv;
+        return {0, Writer.Save()};
+    }
+
+    if (QState == QueueState::Draining)
+    {
+        CbObjectWriter Writer;
+        Writer << "error"sv
+               << "queue is draining"sv;
+        return {0, Writer.Save()};
+    }
+
+    EnqueueResult Result = EnqueueAction(QueueId, ActionObject, Priority);
+
+    // Lsn == 0 signals that the enqueue failed (see EnqueueAction).
+    if (Result.Lsn != 0)
+    {
+        Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); });
+        Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
+        Queue->IdleSince.store(0, std::memory_order_relaxed);
+    }
+
+    return Result;
+}
+
+// Same as EnqueueActionToQueue, but for actions whose worker has already
+// been resolved by the caller (skips function-to-worker resolution).
+ComputeServiceSession::EnqueueResult
+ComputeServiceSession::Impl::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority)
+{
+    Ref<QueueEntry> Queue = FindQueue(QueueId);
+
+    if (!Queue)
+    {
+        CbObjectWriter Writer;
+        Writer << "error"sv
+               << "queue not found"sv;
+        return {0, Writer.Save()};
+    }
+
+    QueueState QState = Queue->State.load();
+    if (QState == QueueState::Cancelled)
+    {
+        CbObjectWriter Writer;
+        Writer << "error"sv
+               << "queue is cancelled"sv;
+        return {0, Writer.Save()};
+    }
+
+    if (QState == QueueState::Draining)
+    {
+        CbObjectWriter Writer;
+        Writer << "error"sv
+               << "queue is draining"sv;
+        return {0, Writer.Save()};
+    }
+
+    EnqueueResult Result = EnqueueResolvedAction(QueueId, Worker, ActionObj, Priority);
+
+    // Lsn == 0 signals that the enqueue failed (see EnqueueResolvedAction).
+    if (Result.Lsn != 0)
+    {
+        Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); });
+        Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
+        Queue->IdleSince.store(0, std::memory_order_relaxed);
+    }
+
+    return Result;
+}
+
+// Writes a "completed" array of LSNs for a queue: finished LSNs whose
+// results are still retained in the results map. An unknown queue yields an
+// empty array.
+void
+ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo)
+{
+    Ref<QueueEntry> Queue = FindQueue(QueueId);
+
+    Cbo.BeginArray("completed");
+
+    if (Queue)
+    {
+        // Queue lock outside, results lock inside — keep this nesting
+        // consistent anywhere both are held together.
+        Queue->m_Lock.WithSharedLock([&] {
+            m_ResultsLock.WithSharedLock([&] {
+                for (int Lsn : Queue->FinishedLsns)
+                {
+                    if (m_ResultsMap.contains(Lsn))
+                    {
+                        Cbo << Lsn;
+                    }
+                }
+            });
+        });
+    }
+
+    Cbo.EndArray();
+}
+
+// Records the terminal state of an action on its owning queue: moves the LSN
+// from the active to the finished set, bumps the matching state counter, and
+// stamps IdleSince when the last active action drains. QueueId 0 means "no
+// queue" and is ignored.
+void
+ComputeServiceSession::Impl::NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState)
+{
+    if (QueueId == 0)
+    {
+        return;
+    }
+
+    Ref<QueueEntry> Queue = FindQueue(QueueId);
+
+    if (!Queue)
+    {
+        return;
+    }
+
+    Queue->m_Lock.WithExclusiveLock([&] {
+        Queue->ActiveLsns.erase(Lsn);
+        Queue->FinishedLsns.insert(Lsn);
+    });
+
+    // fetch_sub returns the previous value, so 1 means this was the last
+    // active action — the queue just became idle.
+    const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed);
+    if (PreviousActive == 1)
+    {
+        Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed);
+    }
+
+    switch (ActionState)
+    {
+        case RunnerAction::State::Completed:
+            Queue->CompletedCount.fetch_add(1, std::memory_order_relaxed);
+            break;
+        case RunnerAction::State::Abandoned:
+            Queue->AbandonedCount.fetch_add(1, std::memory_order_relaxed);
+            break;
+        case RunnerAction::State::Cancelled:
+            Queue->CancelledCount.fetch_add(1, std::memory_order_relaxed);
+            break;
+        default:
+            // Any other terminal state is counted as a failure.
+            Queue->FailedCount.fetch_add(1, std::memory_order_relaxed);
+            break;
+    }
+}
+
+// Deletes explicit queues that have been idle (no active actions) for at
+// least 15 minutes. Candidates are collected under a shared lock; deletion
+// happens outside the lock via DeleteQueue (which takes the exclusive lock).
+void
+ComputeServiceSession::Impl::ExpireCompletedQueues()
+{
+    static constexpr uint64_t ExpiryTimeMs = 15 * 60 * 1000;
+
+    std::vector<int> ExpiredQueueIds;
+
+    m_QueueLock.WithSharedLock([&] {
+        for (const auto& [Id, Queue] : m_Queues)
+        {
+            if (Queue->Implicit)
+            {
+                continue;
+            }
+            // IdleSince == 0 means the queue currently has active work.
+            const uint64_t Idle = Queue->IdleSince.load(std::memory_order_relaxed);
+            if (Idle != 0 && Queue->ActiveCount.load(std::memory_order_relaxed) == 0)
+            {
+                const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(GetHifreqTimerValue() - Idle);
+                if (ElapsedMs >= ExpiryTimeMs)
+                {
+                    ExpiredQueueIds.push_back(Id);
+                }
+            }
+        }
+    });
+
+    for (int QueueId : ExpiredQueueIds)
+    {
+        ZEN_INFO("expiring idle queue {}", QueueId);
+        DeleteQueue(QueueId);
+    }
+}
+
+// One scheduler pass: collects up-to-capacity Pending actions (priority
+// descending, LSN ascending within a priority), then submits them to the
+// runner groups outside of the pending lock. Called from the scheduler
+// thread. The scope guard logs a summary; it is dismissed on the early-out
+// paths (no capacity / nothing to schedule) to avoid log noise.
+void
+ComputeServiceSession::Impl::SchedulePendingActions()
+{
+    ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions");
+    int ScheduledCount = 0;
+    size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
+    size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
+    size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); });
+
+    static Stopwatch DumpRunningTimer;
+
+    auto _ = MakeGuard([&] {
+        ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results",
+                 ScheduledCount,
+                 RunningCount,
+                 m_RetiredCount.load(),
+                 PendingCount,
+                 ResultCount);
+
+        // Periodically (every 30s at most) dump the sorted list of running
+        // LSNs to aid debugging of stuck actions.
+        if (DumpRunningTimer.GetElapsedTimeMs() > 30000)
+        {
+            DumpRunningTimer.Reset();
+
+            std::set<int> RunningList;
+            m_RunningLock.WithSharedLock([&] {
+                for (auto& [K, V] : m_RunningMap)
+                {
+                    RunningList.insert(K);
+                }
+            });
+
+            ExtendableStringBuilder<1024> RunningString;
+            for (int i : RunningList)
+            {
+                if (RunningString.Size())
+                {
+                    RunningString << ", ";
+                }
+
+                RunningString.Append(IntNum(i));
+            }
+
+            ZEN_INFO("running: {}", RunningString);
+        }
+    });
+
+    size_t Capacity = QueryCapacity();
+
+    if (!Capacity)
+    {
+        _.Dismiss();
+
+        return;
+    }
+
+    std::vector<Ref<RunnerAction>> ActionsToSchedule;
+
+    // Pull actions to schedule from the pending queue, we will
+    // try to submit these to the runner outside of the lock. Note
+    // that because of how the state transitions work it's not
+    // actually the case that all of these actions will still be
+    // pending by the time we try to submit them, but that's fine.
+    //
+    // Also note that the m_PendingActions list is not maintained
+    // here, that's done periodically in SchedulePendingActions()
+
+    m_PendingLock.WithExclusiveLock([&] {
+        // Paused (or a later state) suspends scheduling entirely.
+        if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused)
+        {
+            return;
+        }
+
+        if (m_PendingActions.empty())
+        {
+            return;
+        }
+
+        for (auto& [Lsn, Pending] : m_PendingActions)
+        {
+            switch (Pending->ActionState())
+            {
+                case RunnerAction::State::Pending:
+                    ActionsToSchedule.push_back(Pending);
+                    break;
+
+                case RunnerAction::State::Submitting:
+                    break; // already claimed by async submission
+
+                case RunnerAction::State::Running:
+                case RunnerAction::State::Completed:
+                case RunnerAction::State::Failed:
+                case RunnerAction::State::Abandoned:
+                case RunnerAction::State::Cancelled:
+                    break;
+
+                default:
+                case RunnerAction::State::New:
+                    ZEN_WARN("unexpected state {} for pending action {}", static_cast<int>(Pending->ActionState()), Pending->ActionLsn);
+                    break;
+            }
+        }
+
+        // Sort by priority descending, then by LSN ascending (FIFO within same priority)
+        std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) {
+            if (A->Priority != B->Priority)
+            {
+                return A->Priority > B->Priority;
+            }
+            return A->ActionLsn < B->ActionLsn;
+        });
+
+        if (ActionsToSchedule.size() > Capacity)
+        {
+            ActionsToSchedule.resize(Capacity);
+        }
+
+        PendingCount = m_PendingActions.size();
+    });
+
+    if (ActionsToSchedule.empty())
+    {
+        _.Dismiss();
+        return;
+    }
+
+    ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size());
+
+    Stopwatch SubmitTimer;
+    std::vector<SubmitResult> SubmitResults = SubmitActions(ActionsToSchedule);
+
+    int NotAcceptedCount = 0;
+    int ScheduledActionCount = 0;
+
+    for (const SubmitResult& SubResult : SubmitResults)
+    {
+        if (SubResult.IsAccepted)
+        {
+            ++ScheduledActionCount;
+        }
+        else
+        {
+            ++NotAcceptedCount;
+        }
+    }
+
+    ZEN_INFO("scheduled {} pending actions in {} ({} rejected)",
+             ScheduledActionCount,
+             NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
+             NotAcceptedCount);
+
+    ScheduledCount += ScheduledActionCount;
+    PendingCount -= ScheduledActionCount;
+}
+
+// Main loop of the scheduler thread: waits on the scheduling event (with a
+// 100ms timeout when work is pending, 500ms otherwise), drains action state
+// updates, auto-transitions Draining -> Paused once all work is done, syncs
+// coordinator state, schedules pending actions, and periodically sweeps idle
+// queues. Exits when m_SchedulingThreadEnabled is cleared.
+void
+ComputeServiceSession::Impl::SchedulerThreadFunction()
+{
+    SetCurrentThreadName("Function_Scheduler");
+
+    auto _ = MakeGuard([&] { ZEN_INFO("scheduler thread exiting"); });
+
+    do
+    {
+        int TimeoutMs = 500;
+
+        auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
+
+        // Tick faster while there is pending work to schedule.
+        if (PendingCount)
+        {
+            TimeoutMs = 100;
+        }
+
+        const bool WasSignaled = m_SchedulingThreadEvent.Wait(TimeoutMs);
+
+        if (m_SchedulingThreadEnabled == false)
+        {
+            return;
+        }
+
+        if (WasSignaled)
+        {
+            m_SchedulingThreadEvent.Reset();
+        }
+
+        ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}",
+                  m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }),
+                  PendingCount,
+                  m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }),
+                  m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }),
+                  TimeoutMs);
+
+        HandleActionUpdates();
+
+        // Auto-transition Draining → Paused when all work is done
+        if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining)
+        {
+            size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
+            size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
+
+            if (Pending == 0 && Running == 0)
+            {
+                // compare_exchange guards against a racing state change
+                // between the load above and the transition here.
+                SessionState Expected = SessionState::Draining;
+                if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel))
+                {
+                    ZEN_INFO("session state: Draining -> Paused (all work completed)");
+                }
+            }
+        }
+
+        UpdateCoordinatorState();
+        SchedulePendingActions();
+
+        static constexpr uint64_t QueueExpirySweepIntervalMs = 30000;
+        if (m_QueueExpiryTimer.GetElapsedTimeMs() >= QueueExpirySweepIntervalMs)
+        {
+            m_QueueExpiryTimer.Reset();
+            ExpireCompletedQueues();
+        }
+    } while (m_SchedulingThreadEnabled);
+}
+
+// Queues an action for state-update processing by the scheduler thread and
+// wakes it up. Called when an action's state changes.
+void
+ComputeServiceSession::Impl::PostUpdate(RunnerAction* Action)
+{
+    m_UpdatedActionsLock.WithExclusiveLock([&] { m_UpdatedActions.emplace_back(Action); });
+    m_SchedulingThreadEvent.Set();
+}
+
+// Resolves the retry budget for a queue: the queue config's positive
+// "max_retries" value if present, otherwise kDefaultMaxRetries. The implicit
+// queue (id 0) always uses the default.
+int
+ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId)
+{
+    if (QueueId == 0)
+    {
+        return kDefaultMaxRetries;
+    }
+
+    CbObject Config = GetQueueConfig(QueueId);
+
+    if (Config)
+    {
+        int Value = Config["max_retries"].AsInt32(0);
+
+        if (Value > 0)
+        {
+            return Value;
+        }
+    }
+
+    return kDefaultMaxRetries;
+}
+
+// Manually reschedules a Failed/Abandoned action that still has retry budget
+// left. Validation and removal from the results map happen atomically under
+// m_ResultsLock so concurrent reschedule calls cannot double-remove the same
+// action; queue bookkeeping and the state reset happen afterwards, outside
+// the lock.
+ComputeServiceSession::RescheduleResult
+ComputeServiceSession::Impl::RescheduleAction(int ActionLsn)
+{
+    Ref<RunnerAction> Action;
+    RunnerAction::State State;
+    RescheduleResult ValidationError;
+    bool Removed = false;
+
+    // Find, validate, and remove atomically under a single lock scope to prevent
+    // concurrent RescheduleAction calls from double-removing the same action.
+    m_ResultsLock.WithExclusiveLock([&] {
+        auto It = m_ResultsMap.find(ActionLsn);
+        if (It == m_ResultsMap.end())
+        {
+            ValidationError = {.Success = false, .Error = "Action not found in results"};
+            return;
+        }
+
+        Action = It->second;
+        State = Action->ActionState();
+
+        if (State != RunnerAction::State::Failed && State != RunnerAction::State::Abandoned)
+        {
+            ValidationError = {.Success = false, .Error = "Action is not in a failed or abandoned state"};
+            return;
+        }
+
+        int MaxRetries = GetMaxRetriesForQueue(Action->QueueId);
+        if (Action->RetryCount.load(std::memory_order_relaxed) >= MaxRetries)
+        {
+            ValidationError = {.Success = false, .Error = "Retry limit reached"};
+            return;
+        }
+
+        m_ResultsMap.erase(It);
+        Removed = true;
+    });
+
+    if (!Removed)
+    {
+        return ValidationError;
+    }
+
+    // Undo the queue-side completion bookkeeping: the LSN becomes active
+    // again and the terminal-state counter is rolled back.
+    if (Action->QueueId != 0)
+    {
+        Ref<QueueEntry> Queue = FindQueue(Action->QueueId);
+
+        if (Queue)
+        {
+            Queue->m_Lock.WithExclusiveLock([&] {
+                Queue->FinishedLsns.erase(ActionLsn);
+                Queue->ActiveLsns.insert(ActionLsn);
+            });
+
+            Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
+            Queue->IdleSince.store(0, std::memory_order_relaxed);
+
+            if (State == RunnerAction::State::Failed)
+            {
+                Queue->FailedCount.fetch_sub(1, std::memory_order_relaxed);
+            }
+            else
+            {
+                Queue->AbandonedCount.fetch_sub(1, std::memory_order_relaxed);
+            }
+        }
+    }
+
+    // Reset action state — this calls PostUpdate() internally
+    Action->ResetActionStateToPending();
+
+    int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed);
+    ZEN_INFO("action {} ({}) manually rescheduled (retry {})", Action->ActionId, ActionLsn, NewRetryCount);
+
+    return {.Success = true, .RetryCount = NewRetryCount};
+}
+
+// Drains the queued action-state updates (posted via PostUpdate) and moves
+// each action between the pending/running/results maps according to its
+// current state. Failed/Abandoned actions with retry budget left are reset
+// to Pending instead of being retired. Runs on the scheduler thread.
+void
+ComputeServiceSession::Impl::HandleActionUpdates()
+{
+    ZEN_TRACE_CPU("ComputeServiceSession::HandleActionUpdates");
+
+    // Drain the update queue atomically
+    std::vector<Ref<RunnerAction>> UpdatedActions;
+    m_UpdatedActionsLock.WithExclusiveLock([&] { std::swap(UpdatedActions, m_UpdatedActions); });
+
+    std::unordered_set<int> SeenLsn;
+
+    // Process each action's latest state, deduplicating by LSN.
+    //
+    // This is safe because state transitions are monotonically increasing by enum
+    // rank (Pending < Submitting < Running < Completed/Failed/Cancelled), so
+    // SetActionState rejects any transition to a lower-ranked state. By the time
+    // we read ActionState() here, it reflects the highest state reached — making
+    // the first occurrence per LSN authoritative and duplicates redundant.
+    for (Ref<RunnerAction>& Action : UpdatedActions)
+    {
+        const int ActionLsn = Action->ActionLsn;
+
+        if (auto [It, Inserted] = SeenLsn.insert(ActionLsn); Inserted)
+        {
+            switch (Action->ActionState())
+            {
+                // Newly enqueued — add to pending map for scheduling
+                case RunnerAction::State::Pending:
+                    m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
+                    break;
+
+                // Async submission in progress — remains in pending map
+                case RunnerAction::State::Submitting:
+                    break;
+
+                // Dispatched to a runner — move from pending to running
+                case RunnerAction::State::Running:
+                    m_RunningLock.WithExclusiveLock([&] {
+                        m_PendingLock.WithExclusiveLock([&] {
+                            m_RunningMap.insert({ActionLsn, Action});
+                            m_PendingActions.erase(ActionLsn);
+                        });
+                    });
+                    ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn);
+                    break;
+
+                // Terminal states — move to results, record history, notify queue
+                case RunnerAction::State::Completed:
+                case RunnerAction::State::Failed:
+                case RunnerAction::State::Abandoned:
+                case RunnerAction::State::Cancelled:
+                {
+                    auto TerminalState = Action->ActionState();
+
+                    // Automatic retry for Failed/Abandoned actions with retries remaining.
+                    // Skip retries when the session itself is abandoned — those actions
+                    // were intentionally abandoned and should not be rescheduled.
+                    if ((TerminalState == RunnerAction::State::Failed || TerminalState == RunnerAction::State::Abandoned) &&
+                        m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned)
+                    {
+                        int MaxRetries = GetMaxRetriesForQueue(Action->QueueId);
+
+                        if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries)
+                        {
+                            // Remove from whichever active map the action is in before resetting
+                            m_RunningLock.WithExclusiveLock([&] {
+                                m_PendingLock.WithExclusiveLock([&] {
+                                    if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
+                                    {
+                                        m_PendingActions.erase(ActionLsn);
+                                    }
+                                    else
+                                    {
+                                        m_RunningMap.erase(FindIt);
+                                    }
+                                });
+                            });
+
+                            // Reset triggers PostUpdate() which re-enters the action as Pending
+                            Action->ResetActionStateToPending();
+                            int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed);
+
+                            ZEN_INFO("action {} ({}) auto-rescheduled (retry {}/{})",
+                                     Action->ActionId,
+                                     ActionLsn,
+                                     NewRetryCount,
+                                     MaxRetries);
+                            break;
+                        }
+                    }
+
+                    // Remove from whichever active map the action is in
+                    m_RunningLock.WithExclusiveLock([&] {
+                        m_PendingLock.WithExclusiveLock([&] {
+                            if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
+                            {
+                                m_PendingActions.erase(ActionLsn);
+                            }
+                            else
+                            {
+                                m_RunningMap.erase(FindIt);
+                            }
+                        });
+                    });
+
+                    m_ResultsLock.WithExclusiveLock([&] {
+                        m_ResultsMap[ActionLsn] = Action;
+
+                        // Append to bounded action history ring
+                        m_ActionHistoryLock.WithExclusiveLock([&] {
+                            ActionHistoryEntry Entry{.Lsn = ActionLsn,
+                                                     .QueueId = Action->QueueId,
+                                                     .ActionId = Action->ActionId,
+                                                     .WorkerId = Action->Worker.WorkerId,
+                                                     .ActionDescriptor = Action->ActionObj,
+                                                     .ExecutionLocation = std::move(Action->ExecutionLocation),
+                                                     .Succeeded = TerminalState == RunnerAction::State::Completed,
+                                                     .CpuSeconds = Action->CpuSeconds.load(std::memory_order_relaxed),
+                                                     .RetryCount = Action->RetryCount.load(std::memory_order_relaxed)};
+
+                            std::copy(std::begin(Action->Timestamps), std::end(Action->Timestamps), std::begin(Entry.Timestamps));
+
+                            m_ActionHistory.push_back(std::move(Entry));
+
+                            // Bounded ring: drop the oldest entry beyond the limit.
+                            if (m_ActionHistory.size() > m_HistoryLimit)
+                            {
+                                m_ActionHistory.pop_front();
+                            }
+                        });
+                    });
+                    m_RetiredCount.fetch_add(1);
+                    m_ResultRate.Mark(1);
+                    ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}",
+                              Action->ActionId,
+                              ActionLsn,
+                              TerminalState == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE");
+                    NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
+                    break;
+                }
+            }
+        }
+    }
+}
+
+// Combined submission capacity across both runner groups.
+size_t
+ComputeServiceSession::Impl::QueryCapacity()
+{
+    return m_LocalRunnerGroup.QueryCapacity() + m_RemoteRunnerGroup.QueryCapacity();
+}
+
+// Batch-submits actions: local runner group first, then retries every
+// locally-rejected action against the remote group. The returned vector is
+// index-aligned with the input.
+std::vector<SubmitResult>
+ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
+{
+    ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions");
+    std::vector<SubmitResult> Results(Actions.size());
+
+    // First try submitting the batch to local runners in parallel
+
+    std::vector<SubmitResult> LocalResults = m_LocalRunnerGroup.SubmitActions(Actions);
+    std::vector<size_t> RemoteIndices;
+    std::vector<Ref<RunnerAction>> RemoteActions;
+
+    // RemoteIndices remembers each rejected action's position in the input
+    // so remote results can be scattered back into the right slots.
+    for (size_t i = 0; i < Actions.size(); ++i)
+    {
+        if (LocalResults[i].IsAccepted)
+        {
+            Results[i] = std::move(LocalResults[i]);
+        }
+        else
+        {
+            RemoteIndices.push_back(i);
+            RemoteActions.push_back(Actions[i]);
+        }
+    }
+
+    // Submit remaining actions to remote runners in parallel
+    if (!RemoteActions.empty())
+    {
+        std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions);
+
+        for (size_t j = 0; j < RemoteIndices.size(); ++j)
+        {
+            Results[RemoteIndices[j]] = std::move(RemoteResults[j]);
+        }
+    }
+
+    return Results;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+ComputeServiceSession::ComputeServiceSession(ChunkResolver& InChunkResolver)
+{
+ m_Impl = std::make_unique<Impl>(this, InChunkResolver);
+}
+
+ComputeServiceSession::~ComputeServiceSession()
+{
+ Shutdown();
+}
+
+// The ComputeServiceSession public API below follows the pimpl idiom: each
+// method forwards directly to the matching member on m_Impl.
+bool
+ComputeServiceSession::IsHealthy()
+{
+    return m_Impl->IsHealthy();
+}
+
+void
+ComputeServiceSession::WaitUntilReady()
+{
+    m_Impl->WaitUntilReady();
+}
+
+void
+ComputeServiceSession::Shutdown()
+{
+    m_Impl->Shutdown();
+}
+
+ComputeServiceSession::SessionState
+ComputeServiceSession::GetSessionState() const
+{
+    // Relaxed load: this is a point-in-time snapshot of the session state;
+    // no ordering with respect to other session data is implied.
+    return m_Impl->m_SessionState.load(std::memory_order_relaxed);
+}
+
+bool
+ComputeServiceSession::RequestStateTransition(SessionState NewState)
+{
+    return m_Impl->RequestStateTransition(NewState);
+}
+
+void
+ComputeServiceSession::SetOrchestratorEndpoint(std::string_view Endpoint)
+{
+    m_Impl->SetOrchestratorEndpoint(Endpoint);
+}
+
+void
+ComputeServiceSession::SetOrchestratorBasePath(std::filesystem::path BasePath)
+{
+    m_Impl->SetOrchestratorBasePath(std::move(BasePath));
+}
+
+void
+ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath)
+{
+    m_Impl->StartRecording(InResolver, RecordingPath);
+}
+
+void
+ComputeServiceSession::StopRecording()
+{
+    m_Impl->StopRecording();
+}
+
+ComputeServiceSession::ActionCounts
+ComputeServiceSession::GetActionCounts()
+{
+    return m_Impl->GetActionCounts();
+}
+
+void
+ComputeServiceSession::EmitStats(CbObjectWriter& Cbo)
+{
+    m_Impl->EmitStats(Cbo);
+}
+
+// Worker registry queries — forwarded to the Impl.
+std::vector<IoHash>
+ComputeServiceSession::GetKnownWorkerIds()
+{
+    return m_Impl->GetKnownWorkerIds();
+}
+
+WorkerDesc
+ComputeServiceSession::GetWorkerDescriptor(const IoHash& WorkerId)
+{
+    return m_Impl->GetWorkerDescriptor(WorkerId);
+}
+
+// Creates a platform-appropriate local process runner, seeds it with the
+// currently known workers, and hands it to the local runner group. The raw
+// NewRunner pointer is passed to AddRunner — presumably the group takes
+// ownership; confirm in the runner-group implementation.
+void
+ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions)
+{
+    ZEN_TRACE_CPU("ComputeServiceSession::AddLocalRunner");
+
+# if ZEN_PLATFORM_LINUX
+    auto* NewRunner = new LinuxProcessRunner(InChunkResolver,
+                                             BasePath,
+                                             m_Impl->m_DeferredDeleter,
+                                             m_Impl->m_LocalSubmitPool,
+                                             false,
+                                             MaxConcurrentActions);
+# elif ZEN_PLATFORM_WINDOWS
+    auto* NewRunner = new WindowsProcessRunner(InChunkResolver,
+                                               BasePath,
+                                               m_Impl->m_DeferredDeleter,
+                                               m_Impl->m_LocalSubmitPool,
+                                               false,
+                                               MaxConcurrentActions);
+# elif ZEN_PLATFORM_MAC
+    auto* NewRunner =
+        new MacProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, false, MaxConcurrentActions);
+# endif
+
+    // NOTE(review): on a platform matching none of the three macros above,
+    // NewRunner is undeclared and this fails to compile — likely intentional
+    // (unsupported platforms are caught at build time), but worth confirming.
+    m_Impl->SyncWorkersToRunner(*NewRunner);
+    m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner);
+}
+
+// Creates an HTTP-backed remote runner for the given host, seeds it with the
+// known workers, and adds it to the remote runner group.
+void
+ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName)
+{
+    ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner");
+
+    auto* NewRunner = new RemoteHttpRunner(InChunkResolver, BasePath, HostName, m_Impl->m_RemoteSubmitPool);
+    m_Impl->SyncWorkersToRunner(*NewRunner);
+    m_Impl->m_RemoteRunnerGroup.AddRunner(NewRunner);
+}
+
+// Queue API — forwarded to the Impl. The two non-queue-qualified enqueue
+// overloads target the session's implicit queue (m_ImplicitQueueId).
+ComputeServiceSession::EnqueueResult
+ComputeServiceSession::EnqueueAction(CbObject ActionObject, int Priority)
+{
+    return m_Impl->EnqueueActionToQueue(m_Impl->m_ImplicitQueueId, ActionObject, Priority);
+}
+
+ComputeServiceSession::EnqueueResult
+ComputeServiceSession::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority)
+{
+    return m_Impl->EnqueueResolvedActionToQueue(m_Impl->m_ImplicitQueueId, Worker, ActionObj, RequestPriority);
+}
+ComputeServiceSession::CreateQueueResult
+ComputeServiceSession::CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config)
+{
+    return m_Impl->CreateQueue(Tag, std::move(Metadata), std::move(Config));
+}
+
+CbObject
+ComputeServiceSession::GetQueueMetadata(int QueueId)
+{
+    return m_Impl->GetQueueMetadata(QueueId);
+}
+
+CbObject
+ComputeServiceSession::GetQueueConfig(int QueueId)
+{
+    return m_Impl->GetQueueConfig(QueueId);
+}
+
+std::vector<int>
+ComputeServiceSession::GetQueueIds()
+{
+    return m_Impl->GetQueueIds();
+}
+
+ComputeServiceSession::QueueStatus
+ComputeServiceSession::GetQueueStatus(int QueueId)
+{
+    return m_Impl->GetQueueStatus(QueueId);
+}
+
+void
+ComputeServiceSession::CancelQueue(int QueueId)
+{
+    m_Impl->CancelQueue(QueueId);
+}
+
+void
+ComputeServiceSession::DrainQueue(int QueueId)
+{
+    m_Impl->DrainQueue(QueueId);
+}
+
+void
+ComputeServiceSession::DeleteQueue(int QueueId)
+{
+    m_Impl->DeleteQueue(QueueId);
+}
+
+void
+ComputeServiceSession::GetQueueCompleted(int QueueId, CbWriter& Cbo)
+{
+    m_Impl->GetQueueCompleted(QueueId, Cbo);
+}
+
+ComputeServiceSession::EnqueueResult
+ComputeServiceSession::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority)
+{
+    return m_Impl->EnqueueActionToQueue(QueueId, ActionObject, Priority);
+}
+
+ComputeServiceSession::EnqueueResult
+ComputeServiceSession::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority)
+{
+    return m_Impl->EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority);
+}
+
+// Worker registration, action-result retrieval, and lifecycle hooks —
+// forwarded to the Impl.
+void
+ComputeServiceSession::RegisterWorker(CbPackage Worker)
+{
+    m_Impl->RegisterWorker(Worker);
+}
+
+HttpResponseCode
+ComputeServiceSession::GetActionResult(int ActionLsn, CbPackage& OutResultPackage)
+{
+    return m_Impl->GetActionResult(ActionLsn, OutResultPackage);
+}
+
+HttpResponseCode
+ComputeServiceSession::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage)
+{
+    return m_Impl->FindActionResult(ActionId, OutResultPackage);
+}
+
+void
+ComputeServiceSession::RetireActionResult(int ActionLsn)
+{
+    m_Impl->RetireActionResult(ActionLsn);
+}
+
+ComputeServiceSession::RescheduleResult
+ComputeServiceSession::RescheduleAction(int ActionLsn)
+{
+    return m_Impl->RescheduleAction(ActionLsn);
+}
+
+std::vector<ComputeServiceSession::RunningActionInfo>
+ComputeServiceSession::GetRunningActions()
+{
+    return m_Impl->GetRunningActions();
+}
+
+std::vector<ComputeServiceSession::ActionHistoryEntry>
+ComputeServiceSession::GetActionHistory(int Limit)
+{
+    return m_Impl->GetActionHistory(Limit);
+}
+
+std::vector<ComputeServiceSession::ActionHistoryEntry>
+ComputeServiceSession::GetQueueHistory(int QueueId, int Limit)
+{
+    return m_Impl->GetQueueHistory(QueueId, Limit);
+}
+
+void
+ComputeServiceSession::GetCompleted(CbWriter& Cbo)
+{
+    m_Impl->GetCompleted(Cbo);
+}
+
+void
+ComputeServiceSession::PostUpdate(RunnerAction* Action)
+{
+    m_Impl->PostUpdate(Action);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// Empty symbol — presumably referenced from another translation unit to force
+// the linker to keep this object file; confirm usage at the referencing site.
+void
+computeservice_forcelink()
+{
+}
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp
new file mode 100644
index 000000000..e82a40781
--- /dev/null
+++ b/src/zencompute/httpcomputeservice.cpp
@@ -0,0 +1,1643 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencompute/httpcomputeservice.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include "runners/functionrunner.h"
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/compress.h>
+# include <zencore/except.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/iobuffer.h>
+# include <zencore/iohash.h>
+# include <zencore/logging.h>
+# include <zencore/system.h>
+# include <zencore/thread.h>
+# include <zencore/trace.h>
+# include <zencore/uid.h>
+# include <zenstore/cidstore.h>
+# include <zentelemetry/stats.h>
+
+# include <span>
+# include <unordered_map>
+
+using namespace std::literals;
+
+namespace zen::compute {
+
+// Character sets and route-capture validators used by RegisterRoutes():
+// a decimal integer, a 40-hex-char IoHash, and a 24-hex-char Oid.
+// NOTE(review): these are non-const globals with external linkage in a .cpp;
+// an anonymous namespace (or const/constexpr) would avoid name collisions.
+constinit AsciiSet g_DecimalSet("0123456789");
+constinit AsciiSet g_HexSet("0123456789abcdefABCDEF");
+
+auto DecimalMatcher = [](std::string_view Str) { return AsciiSet::HasOnly(Str, g_DecimalSet); };
+auto IoHashMatcher = [](std::string_view Str) { return Str.size() == 40 && AsciiSet::HasOnly(Str, g_HexSet); };
+auto OidMatcher = [](std::string_view Str) { return Str.size() == 24 && AsciiSet::HasOnly(Str, g_HexSet); };
+
+//////////////////////////////////////////////////////////////////////////
+
+// Private implementation for HttpComputeService: owns the HTTP route table,
+// the underlying ComputeServiceSession, and the registry mapping remote-queue
+// OID tokens to local queue IDs.
+struct HttpComputeService::Impl
+{
+    HttpComputeService* m_Self;
+    CidStore& m_CidStore;
+    IHttpStatsService& m_StatsService;
+    LoggerRef m_Log;
+    std::filesystem::path m_BaseDir;
+    HttpRequestRouter m_Router;
+    ComputeServiceSession m_ComputeService;
+    SystemMetricsTracker m_MetricsTracker;
+
+    // Metrics
+
+    metrics::OperationTiming m_HttpRequests;
+
+    // Per-remote-queue metadata, shared across all lookup maps below.
+
+    struct RemoteQueueInfo : RefCounted
+    {
+        int QueueId = 0;
+        Oid Token;
+        std::string IdempotencyKey; // empty if no idempotency key was provided
+        std::string ClientHostname; // empty if no hostname was provided
+    };
+
+    // Remote queue registry — all three maps share the same RemoteQueueInfo objects.
+    // All maps are guarded by m_RemoteQueueLock.
+
+    RwLock m_RemoteQueueLock;
+    std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken; // Token → info
+    std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId; // QueueId → info
+    std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag; // idempotency key → info
+
+    LoggerRef Log() { return m_Log; }
+
+    // Queue-reference resolution (defined later in this file): translates an
+    // OID token, or a {queueref} route capture, into a local queue ID.
+    int ResolveQueueToken(const Oid& Token);
+    int ResolveQueueRef(HttpServerRequest& HttpReq, std::string_view Capture);
+
+    // Attachment-ingest accounting: total vs newly-stored counts and bytes.
+    struct IngestStats
+    {
+        int Count = 0;
+        int NewCount = 0;
+        uint64_t Bytes = 0;
+        uint64_t NewBytes = 0;
+    };
+
+    IngestStats IngestPackageAttachments(const CbPackage& Package);
+    bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList);
+    void HandleWorkersGet(HttpServerRequest& HttpReq);
+    void HandleWorkersAllGet(HttpServerRequest& HttpReq);
+    void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status);
+    void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId);
+
+    void RegisterRoutes();
+
+    // Constructs the service: creates one local runner rooted under
+    // BaseDir/"local", blocks until the compute session is ready, then
+    // registers the stats handler and all HTTP routes.
+    Impl(HttpComputeService* Self,
+         CidStore& InCidStore,
+         IHttpStatsService& StatsService,
+         const std::filesystem::path& BaseDir,
+         int32_t MaxConcurrentActions)
+        : m_Self(Self)
+        , m_CidStore(InCidStore)
+        , m_StatsService(StatsService)
+        , m_Log(logging::Get("compute"))
+        , m_BaseDir(BaseDir)
+        , m_ComputeService(InCidStore)
+    {
+        m_ComputeService.AddLocalRunner(InCidStore, m_BaseDir / "local", MaxConcurrentActions);
+        m_ComputeService.WaitUntilReady();
+        m_StatsService.RegisterHandler("compute", *m_Self);
+        RegisterRoutes();
+    }
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+void
+HttpComputeService::Impl::RegisterRoutes()
+{
+ m_Router.AddMatcher("lsn", DecimalMatcher);
+ m_Router.AddMatcher("worker", IoHashMatcher);
+ m_Router.AddMatcher("action", IoHashMatcher);
+ m_Router.AddMatcher("queue", DecimalMatcher);
+ m_Router.AddMatcher("oidtoken", OidMatcher);
+ m_Router.AddMatcher("queueref", [](std::string_view Str) { return DecimalMatcher(Str) || OidMatcher(Str); });
+
+ m_Router.RegisterRoute(
+ "ready",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ if (m_ComputeService.IsHealthy())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok");
+ }
+
+ return HttpReq.WriteResponse(HttpResponseCode::ServiceUnavailable);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "abandon",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ if (!HttpReq.IsLocalMachineRequest())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::Forbidden);
+ }
+
+ bool Success = m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Abandoned);
+
+ if (Success)
+ {
+ CbObjectWriter Cbo;
+ Cbo << "state"sv
+ << "Abandoned"sv;
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+
+ CbObjectWriter Cbo;
+ Cbo << "error"sv
+ << "Cannot transition to Abandoned from current state"sv;
+ return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "workers",
+ [this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "workers/{worker}",
+ [this](HttpRouterRequest& Req) { HandleWorkerRequest(Req.ServerRequest(), IoHash::FromHexString(Req.GetCapture(1))); },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "jobs/completed",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObjectWriter Cbo;
+ m_ComputeService.GetCompleted(Cbo);
+
+ ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query());
+ Cbo.BeginObject("metrics");
+ Describe(Sm, Cbo);
+ Cbo.EndObject();
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "jobs/history",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto QueryParams = HttpReq.GetQueryParams();
+
+ int QueryLimit = 50;
+
+ if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false)
+ {
+ QueryLimit = ParseInt<int>(LimitParam).value_or(50);
+ }
+
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("history");
+ for (const auto& Entry : m_ComputeService.GetActionHistory(QueryLimit))
+ {
+ Cbo.BeginObject();
+ Cbo << "lsn"sv << Entry.Lsn;
+ Cbo << "queueId"sv << Entry.QueueId;
+ Cbo << "actionId"sv << Entry.ActionId;
+ Cbo << "workerId"sv << Entry.WorkerId;
+ Cbo << "succeeded"sv << Entry.Succeeded;
+ Cbo << "actionDescriptor"sv << Entry.ActionDescriptor;
+ if (Entry.CpuSeconds > 0.0f)
+ {
+ Cbo.AddFloat("cpuSeconds"sv, Entry.CpuSeconds);
+ }
+ if (Entry.RetryCount > 0)
+ {
+ Cbo << "retry_count"sv << Entry.RetryCount;
+ }
+
+ for (const auto& Timestamp : Entry.Timestamps)
+ {
+ Cbo.AddInteger(
+ fmt::format("time_{}"sv, RunnerAction::ToString(static_cast<RunnerAction::State>(&Timestamp - Entry.Timestamps))),
+ Timestamp);
+ }
+ Cbo.EndObject();
+ }
+ Cbo.EndArray();
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "jobs/running",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ auto Running = m_ComputeService.GetRunningActions();
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("running");
+ for (const auto& Info : Running)
+ {
+ Cbo.BeginObject();
+ Cbo << "lsn"sv << Info.Lsn;
+ Cbo << "queueId"sv << Info.QueueId;
+ Cbo << "actionId"sv << Info.ActionId;
+ if (Info.CpuUsagePercent >= 0.0f)
+ {
+ Cbo.AddFloat("cpuUsage"sv, Info.CpuUsagePercent);
+ }
+ if (Info.CpuSeconds > 0.0f)
+ {
+ Cbo.AddFloat("cpuSeconds"sv, Info.CpuSeconds);
+ }
+ Cbo.EndObject();
+ }
+ Cbo.EndArray();
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "jobs/{lsn}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int ActionLsn = ParseInt<int>(Req.GetCapture(1)).value_or(0);
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ CbPackage Output;
+ HttpResponseCode ResponseCode = m_ComputeService.GetActionResult(ActionLsn, Output);
+
+ if (ResponseCode == HttpResponseCode::OK)
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ }
+ else
+ {
+ HttpReq.WriteResponse(ResponseCode);
+ }
+
+ // Once we've initiated the response we can mark the result
+ // as retired, allowing the service to free any associated
+ // resources. Note that there still needs to be a delay
+ // to allow the transmission to complete, it would be better
+ // if we could issue this once the response is fully sent...
+ m_ComputeService.RetireActionResult(ActionLsn);
+ }
+ break;
+
+ case HttpVerb::kPost:
+ {
+ auto Result = m_ComputeService.RescheduleAction(ActionLsn);
+
+ CbObjectWriter Cbo;
+ if (Result.Success)
+ {
+ Cbo << "lsn"sv << ActionLsn;
+ Cbo << "retry_count"sv << Result.RetryCount;
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ else
+ {
+ Cbo << "error"sv << Result.Error;
+ HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ }
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. The preferred path is the
+ // one which uses the scheduled action lsn for lookups
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2));
+
+ CbPackage Output;
+ if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output);
+ ResponseCode != HttpResponseCode::OK)
+ {
+ ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode))
+
+ if (ResponseCode == HttpResponseCode::NotFound)
+ {
+ return HttpReq.WriteResponse(ResponseCode);
+ }
+
+ return HttpReq.WriteResponse(ResponseCode);
+ }
+
+ ZEN_DEBUG("jobs/{}/{}: OK", Req.GetCapture(1), Req.GetCapture(2))
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "jobs/{worker}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1));
+
+ WorkerDesc Worker = m_ComputeService.GetWorkerDescriptor(WorkerId);
+
+ if (!Worker)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ const auto QueryParams = Req.ServerRequest().GetQueryParams();
+
+ int RequestPriority = -1;
+
+ if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false)
+ {
+ RequestPriority = ParseInt<int>(PriorityParam).value_or(-1);
+ }
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ // TODO: return status of all pending or executing jobs
+ break;
+
+ case HttpVerb::kPost:
+ switch (HttpReq.RequestContentType())
+ {
+ case HttpContentType::kCbObject:
+ {
+ // This operation takes the proposed job spec and identifies which
+ // chunks are not present on this server. This list is then returned in
+ // the "need" list in the response
+
+ IoBuffer Payload = HttpReq.ReadPayload();
+ CbObject ActionObj = LoadCompactBinaryObject(Payload);
+
+ std::vector<IoHash> NeedList;
+
+ ActionObj.IterateAttachments([&](CbFieldView Field) {
+ const IoHash FileHash = Field.AsHash();
+
+ if (!m_CidStore.ContainsChunk(FileHash))
+ {
+ NeedList.push_back(FileHash);
+ }
+ });
+
+ if (NeedList.empty())
+ {
+ // We already have everything, enqueue the action for execution
+
+ if (ComputeServiceSession::EnqueueResult Result =
+ m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority))
+ {
+ ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn);
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
+ }
+
+ return;
+ }
+
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+
+ for (const IoHash& Hash : NeedList)
+ {
+ Cbo << Hash;
+ }
+
+ Cbo.EndArray();
+ CbObject Response = Cbo.Save();
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response);
+ }
+ break;
+
+ case HttpContentType::kCbPackage:
+ {
+ CbPackage Action = HttpReq.ReadPayloadPackage();
+ CbObject ActionObj = Action.GetObject();
+
+ std::span<const CbAttachment> Attachments = Action.GetAttachments();
+
+ int AttachmentCount = 0;
+ int NewAttachmentCount = 0;
+ uint64_t TotalAttachmentBytes = 0;
+ uint64_t TotalNewBytes = 0;
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ ZEN_ASSERT(Attachment.IsCompressedBinary());
+
+ const IoHash DataHash = Attachment.GetHash();
+ CompressedBuffer DataView = Attachment.AsCompressedBinary();
+
+ ZEN_UNUSED(DataHash);
+
+ const uint64_t CompressedSize = DataView.GetCompressedSize();
+
+ TotalAttachmentBytes += CompressedSize;
+ ++AttachmentCount;
+
+ const CidStore::InsertResult InsertResult =
+ m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
+
+ if (InsertResult.New)
+ {
+ TotalNewBytes += CompressedSize;
+ ++NewAttachmentCount;
+ }
+ }
+
+ if (ComputeServiceSession::EnqueueResult Result =
+ m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority))
+ {
+ ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)",
+ ActionObj.GetHash(),
+ Result.Lsn,
+ zen::NiceBytes(TotalAttachmentBytes),
+ AttachmentCount,
+ zen::NiceBytes(TotalNewBytes),
+ NewAttachmentCount);
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
+ }
+
+ return;
+ }
+ break;
+
+ default:
+ break;
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "jobs",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto QueryParams = HttpReq.GetQueryParams();
+
+ int RequestPriority = -1;
+
+ if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false)
+ {
+ RequestPriority = ParseInt<int>(PriorityParam).value_or(-1);
+ }
+
+ // Resolve worker
+
+ //
+
+ switch (HttpReq.RequestContentType())
+ {
+ case HttpContentType::kCbObject:
+ {
+ // This operation takes the proposed job spec and identifies which
+ // chunks are not present on this server. This list is then returned in
+ // the "need" list in the response
+
+ IoBuffer Payload = HttpReq.ReadPayload();
+ CbObject ActionObj = LoadCompactBinaryObject(Payload);
+
+ std::vector<IoHash> NeedList;
+
+ ActionObj.IterateAttachments([&](CbFieldView Field) {
+ const IoHash FileHash = Field.AsHash();
+
+ if (!m_CidStore.ContainsChunk(FileHash))
+ {
+ NeedList.push_back(FileHash);
+ }
+ });
+
+ if (NeedList.empty())
+ {
+ // We already have everything, enqueue the action for execution
+
+ if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority))
+ {
+ ZEN_DEBUG("action accepted (lsn {})", Result.Lsn);
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
+ }
+ else
+ {
+ // Could not resolve?
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
+ }
+ }
+
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+
+ for (const IoHash& Hash : NeedList)
+ {
+ Cbo << Hash;
+ }
+
+ Cbo.EndArray();
+ CbObject Response = Cbo.Save();
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response);
+ }
+
+ case HttpContentType::kCbPackage:
+ {
+ CbPackage Action = HttpReq.ReadPayloadPackage();
+ CbObject ActionObj = Action.GetObject();
+
+ std::span<const CbAttachment> Attachments = Action.GetAttachments();
+
+ int AttachmentCount = 0;
+ int NewAttachmentCount = 0;
+ uint64_t TotalAttachmentBytes = 0;
+ uint64_t TotalNewBytes = 0;
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ ZEN_ASSERT(Attachment.IsCompressedBinary());
+
+ const IoHash DataHash = Attachment.GetHash();
+ CompressedBuffer DataView = Attachment.AsCompressedBinary();
+
+ ZEN_UNUSED(DataHash);
+
+ const uint64_t CompressedSize = DataView.GetCompressedSize();
+
+ TotalAttachmentBytes += CompressedSize;
+ ++AttachmentCount;
+
+ const CidStore::InsertResult InsertResult =
+ m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
+
+ if (InsertResult.New)
+ {
+ TotalNewBytes += CompressedSize;
+ ++NewAttachmentCount;
+ }
+ }
+
+ if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority))
+ {
+ ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)",
+ Result.Lsn,
+ zen::NiceBytes(TotalAttachmentBytes),
+ AttachmentCount,
+ zen::NiceBytes(TotalNewBytes),
+ NewAttachmentCount);
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
+ }
+ else
+ {
+ // Could not resolve?
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
+ }
+ }
+ return;
+ }
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "workers/all",
+ [this](HttpRouterRequest& Req) { HandleWorkersAllGet(Req.ServerRequest()); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}/workers",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0)
+ return;
+ HandleWorkersGet(HttpReq);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}/workers/all",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0)
+ return;
+ HandleWorkersAllGet(HttpReq);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}/workers/{worker}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0)
+ return;
+ HandleWorkerRequest(HttpReq, IoHash::FromHexString(Req.GetCapture(2)));
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "sysinfo",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query());
+
+ CbObjectWriter Cbo;
+ Describe(Sm, Cbo);
+
+ Cbo << "cpu_usage" << Sm.CpuUsagePercent;
+ Cbo << "memory_total" << Sm.SystemMemoryMiB * 1024 * 1024;
+ Cbo << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024;
+ Cbo << "disk_used" << 100 * 1024;
+ Cbo << "disk_total" << 100 * 1024 * 1024;
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "record/start",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ if (!HttpReq.IsLocalMachineRequest())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::Forbidden);
+ }
+
+ m_ComputeService.StartRecording(m_CidStore, m_BaseDir / "recording");
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK);
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "record/stop",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ if (!HttpReq.IsLocalMachineRequest())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::Forbidden);
+ }
+
+ m_ComputeService.StopRecording();
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK);
+ },
+ HttpVerb::kPost);
+
+ // Local-only queue listing and creation
+
+ m_Router.RegisterRoute(
+ "queues",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ if (!HttpReq.IsLocalMachineRequest())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::Forbidden);
+ }
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("queues"sv);
+
+ for (const int QueueId : m_ComputeService.GetQueueIds())
+ {
+ ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId);
+
+ if (!Status.IsValid)
+ {
+ continue;
+ }
+
+ Cbo.BeginObject();
+ WriteQueueDescription(Cbo, QueueId, Status);
+ Cbo.EndObject();
+ }
+
+ Cbo.EndArray();
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+
+ case HttpVerb::kPost:
+ {
+ CbObject Metadata;
+ CbObject Config;
+ if (const CbObject Body = HttpReq.ReadPayloadObject())
+ {
+ Metadata = Body.Find("metadata"sv).AsObject();
+ Config = Body.Find("config"sv).AsObject();
+ }
+
+ ComputeServiceSession::CreateQueueResult Result =
+ m_ComputeService.CreateQueue({}, std::move(Metadata), std::move(Config));
+
+ CbObjectWriter Cbo;
+ Cbo << "queue_id"sv << Result.QueueId;
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ // Queue creation routes — these remain separate since local creates a plain queue
+ // while remote additionally generates an OID token for external access.
+
+ m_Router.RegisterRoute(
+ "queues/remote",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ // Extract optional fields from the request body.
+ // idempotency_key: when present, we return the existing remote queue token for this
+ // key rather than creating a new queue, making the endpoint safe to call concurrently.
+ // hostname: human-readable origin context stored alongside the queue for diagnostics.
+ // metadata: arbitrary CbObject metadata propagated from the originating queue.
+ // config: arbitrary CbObject config propagated from the originating queue.
+ std::string IdempotencyKey;
+ std::string ClientHostname;
+ CbObject Metadata;
+ CbObject Config;
+ if (const CbObject Body = HttpReq.ReadPayloadObject())
+ {
+ IdempotencyKey = std::string(Body["idempotency_key"sv].AsString());
+ ClientHostname = std::string(Body["hostname"sv].AsString());
+ Metadata = Body.Find("metadata"sv).AsObject();
+ Config = Body.Find("config"sv).AsObject();
+ }
+
+ // Stamp the forwarding node's hostname into the metadata so that the
+ // remote side knows which node originated the queue.
+ if (!ClientHostname.empty())
+ {
+ CbObjectWriter MetaWriter;
+ for (auto Field : Metadata)
+ {
+ MetaWriter.AddField(Field.GetName(), Field);
+ }
+ MetaWriter << "via"sv << ClientHostname;
+ Metadata = MetaWriter.Save();
+ }
+
+ RwLock::ExclusiveLockScope _(m_RemoteQueueLock);
+
+ if (!IdempotencyKey.empty())
+ {
+ if (auto It = m_RemoteQueuesByTag.find(IdempotencyKey); It != m_RemoteQueuesByTag.end())
+ {
+ Ref<RemoteQueueInfo> Existing = It->second;
+ if (m_ComputeService.GetQueueStatus(Existing->QueueId).IsValid)
+ {
+ CbObjectWriter Cbo;
+ Cbo << "queue_token"sv << Existing->Token.ToString();
+ Cbo << "queue_id"sv << Existing->QueueId;
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+
+ // Queue has since expired — clean up stale entries and fall through to create a new one
+ m_RemoteQueuesByToken.erase(Existing->Token);
+ m_RemoteQueuesByQueueId.erase(Existing->QueueId);
+ m_RemoteQueuesByTag.erase(It);
+ }
+ }
+
+ ComputeServiceSession::CreateQueueResult Result = m_ComputeService.CreateQueue({}, std::move(Metadata), std::move(Config));
+ Ref<RemoteQueueInfo> InfoRef(new RemoteQueueInfo());
+ InfoRef->QueueId = Result.QueueId;
+ InfoRef->Token = Oid::NewOid();
+ InfoRef->IdempotencyKey = std::move(IdempotencyKey);
+ InfoRef->ClientHostname = std::move(ClientHostname);
+
+ m_RemoteQueuesByToken[InfoRef->Token] = InfoRef;
+ m_RemoteQueuesByQueueId[InfoRef->QueueId] = InfoRef;
+ if (!InfoRef->IdempotencyKey.empty())
+ {
+ m_RemoteQueuesByTag[InfoRef->IdempotencyKey] = InfoRef;
+ }
+
+ CbObjectWriter Cbo;
+ Cbo << "queue_token"sv << InfoRef->Token.ToString();
+ Cbo << "queue_id"sv << InfoRef->QueueId;
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
+ // Unified queue routes — {queueref} accepts both local integer IDs and remote OID tokens.
+ // ResolveQueueRef() handles access control (local-only for integer IDs) and token resolution.
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1));
+
+ if (QueueId == 0)
+ {
+ return;
+ }
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId);
+
+ if (!Status.IsValid)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ CbObjectWriter Cbo;
+ WriteQueueDescription(Cbo, QueueId, Status);
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+
+ case HttpVerb::kDelete:
+ {
+ ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId);
+
+ if (!Status.IsValid)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ m_ComputeService.CancelQueue(QueueId);
+
+ return HttpReq.WriteResponse(HttpResponseCode::NoContent);
+ }
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kDelete);
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}/drain",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1));
+
+ if (QueueId == 0)
+ {
+ return;
+ }
+
+ ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId);
+
+ if (!Status.IsValid)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ m_ComputeService.DrainQueue(QueueId);
+
+ // Return updated queue status
+ Status = m_ComputeService.GetQueueStatus(QueueId);
+
+ CbObjectWriter Cbo;
+ WriteQueueDescription(Cbo, QueueId, Status);
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}/completed",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1));
+
+ if (QueueId == 0)
+ {
+ return;
+ }
+
+ ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId);
+
+ if (!Status.IsValid)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ CbObjectWriter Cbo;
+ m_ComputeService.GetQueueCompleted(QueueId, Cbo);
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}/history",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1));
+
+ if (QueueId == 0)
+ {
+ return;
+ }
+
+ ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId);
+
+ if (!Status.IsValid)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ const auto QueryParams = HttpReq.GetQueryParams();
+
+ int QueryLimit = 50;
+
+ if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false)
+ {
+ QueryLimit = ParseInt<int>(LimitParam).value_or(50);
+ }
+
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("history");
+ for (const auto& Entry : m_ComputeService.GetQueueHistory(QueueId, QueryLimit))
+ {
+ Cbo.BeginObject();
+ Cbo << "lsn"sv << Entry.Lsn;
+ Cbo << "queueId"sv << Entry.QueueId;
+ Cbo << "actionId"sv << Entry.ActionId;
+ Cbo << "workerId"sv << Entry.WorkerId;
+ Cbo << "succeeded"sv << Entry.Succeeded;
+ if (Entry.CpuSeconds > 0.0f)
+ {
+ Cbo.AddFloat("cpuSeconds"sv, Entry.CpuSeconds);
+ }
+ if (Entry.RetryCount > 0)
+ {
+ Cbo << "retry_count"sv << Entry.RetryCount;
+ }
+
+ for (const auto& Timestamp : Entry.Timestamps)
+ {
+ Cbo.AddInteger(
+ fmt::format("time_{}"sv, RunnerAction::ToString(static_cast<RunnerAction::State>(&Timestamp - Entry.Timestamps))),
+ Timestamp);
+ }
+ Cbo.EndObject();
+ }
+ Cbo.EndArray();
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}/running",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1));
+ if (QueueId == 0)
+ {
+ return;
+ }
+ // Filter global running list to this queue
+ auto AllRunning = m_ComputeService.GetRunningActions();
+ std::vector<ComputeServiceSession::RunningActionInfo> Running;
+ for (auto& Info : AllRunning)
+ if (Info.QueueId == QueueId)
+ Running.push_back(Info);
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("running");
+ for (const auto& Info : Running)
+ {
+ Cbo.BeginObject();
+ Cbo << "lsn"sv << Info.Lsn;
+ Cbo << "queueId"sv << Info.QueueId;
+ Cbo << "actionId"sv << Info.ActionId;
+ if (Info.CpuUsagePercent >= 0.0f)
+ {
+ Cbo.AddFloat("cpuUsage"sv, Info.CpuUsagePercent);
+ }
+ if (Info.CpuSeconds > 0.0f)
+ {
+ Cbo.AddFloat("cpuSeconds"sv, Info.CpuSeconds);
+ }
+ Cbo.EndObject();
+ }
+ Cbo.EndArray();
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}/jobs/{worker}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1));
+
+ if (QueueId == 0)
+ {
+ return;
+ }
+
+ const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(2));
+ WorkerDesc Worker = m_ComputeService.GetWorkerDescriptor(WorkerId);
+
+ if (!Worker)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ const auto QueryParams = Req.ServerRequest().GetQueryParams();
+ int RequestPriority = -1;
+
+ if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false)
+ {
+ RequestPriority = ParseInt<int>(PriorityParam).value_or(-1);
+ }
+
+ switch (HttpReq.RequestContentType())
+ {
+ case HttpContentType::kCbObject:
+ {
+ IoBuffer Payload = HttpReq.ReadPayload();
+ CbObject ActionObj = LoadCompactBinaryObject(Payload);
+
+ std::vector<IoHash> NeedList;
+
+ if (!CheckAttachments(ActionObj, NeedList))
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+
+ for (const IoHash& Hash : NeedList)
+ {
+ Cbo << Hash;
+ }
+
+ Cbo.EndArray();
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save());
+ }
+
+ if (ComputeServiceSession::EnqueueResult Result =
+ m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority))
+ {
+ ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, ActionObj.GetHash(), Result.Lsn);
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
+ }
+ else
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
+ }
+ }
+
+ case HttpContentType::kCbPackage:
+ {
+ CbPackage Action = HttpReq.ReadPayloadPackage();
+ CbObject ActionObj = Action.GetObject();
+
+ IngestStats Stats = IngestPackageAttachments(Action);
+
+ if (ComputeServiceSession::EnqueueResult Result =
+ m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority))
+ {
+ ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)",
+ QueueId,
+ ActionObj.GetHash(),
+ Result.Lsn,
+ zen::NiceBytes(Stats.Bytes),
+ Stats.Count,
+ zen::NiceBytes(Stats.NewBytes),
+ Stats.NewCount);
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
+ }
+ else
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
+ }
+ }
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}/jobs",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1));
+
+ if (QueueId == 0)
+ {
+ return;
+ }
+
+ const auto QueryParams = Req.ServerRequest().GetQueryParams();
+ int RequestPriority = -1;
+
+ if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false)
+ {
+ RequestPriority = ParseInt<int>(PriorityParam).value_or(-1);
+ }
+
+ switch (HttpReq.RequestContentType())
+ {
+ case HttpContentType::kCbObject:
+ {
+ IoBuffer Payload = HttpReq.ReadPayload();
+ CbObject ActionObj = LoadCompactBinaryObject(Payload);
+
+ std::vector<IoHash> NeedList;
+
+ if (!CheckAttachments(ActionObj, NeedList))
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+
+ for (const IoHash& Hash : NeedList)
+ {
+ Cbo << Hash;
+ }
+
+ Cbo.EndArray();
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save());
+ }
+
+ if (ComputeServiceSession::EnqueueResult Result =
+ m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority))
+ {
+ ZEN_DEBUG("queue {}: action accepted (lsn {})", QueueId, Result.Lsn);
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
+ }
+ else
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
+ }
+ }
+
+ case HttpContentType::kCbPackage:
+ {
+ CbPackage Action = HttpReq.ReadPayloadPackage();
+ CbObject ActionObj = Action.GetObject();
+
+ IngestStats Stats = IngestPackageAttachments(Action);
+
+ if (ComputeServiceSession::EnqueueResult Result =
+ m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority))
+ {
+ ZEN_DEBUG("queue {}: accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)",
+ QueueId,
+ Result.Lsn,
+ zen::NiceBytes(Stats.Bytes),
+ Stats.Count,
+ zen::NiceBytes(Stats.NewBytes),
+ Stats.NewCount);
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
+ }
+ else
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
+ }
+ }
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}/jobs/{lsn}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1));
+ const int ActionLsn = ParseInt<int>(Req.GetCapture(2)).value_or(0);
+
+ if (QueueId == 0)
+ {
+ return;
+ }
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ ZEN_UNUSED(QueueId);
+
+ CbPackage Output;
+ HttpResponseCode ResponseCode = m_ComputeService.GetActionResult(ActionLsn, Output);
+
+ if (ResponseCode == HttpResponseCode::OK)
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ }
+ else
+ {
+ HttpReq.WriteResponse(ResponseCode);
+ }
+
+ m_ComputeService.RetireActionResult(ActionLsn);
+ }
+ break;
+
+ case HttpVerb::kPost:
+ {
+ ZEN_UNUSED(QueueId);
+
+ auto Result = m_ComputeService.RescheduleAction(ActionLsn);
+
+ CbObjectWriter Cbo;
+ if (Result.Success)
+ {
+ Cbo << "lsn"sv << ActionLsn;
+ Cbo << "retry_count"sv << Result.RetryCount;
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ else
+ {
+ Cbo << "error"sv << Result.Error;
+ HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ }
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
// Constructs the HTTP facade for the compute service. All state lives in the
// pimpl (Impl). NOTE(review): the "compute" stats handler unregistered by the
// destructor is presumably registered inside Impl's constructor — confirm.
HttpComputeService::HttpComputeService(CidStore& InCidStore,
                                       IHttpStatsService& StatsService,
                                       const std::filesystem::path& BaseDir,
                                       int32_t MaxConcurrentActions)
: m_Impl(std::make_unique<Impl>(this, InCidStore, StatsService, BaseDir, MaxConcurrentActions))
{
}
+
// Unregisters this instance from the stats service so no stats callbacks
// arrive after destruction begins.
HttpComputeService::~HttpComputeService()
{
    m_Impl->m_StatsService.UnregisterHandler("compute", *this);
}
+
// Forwards shutdown to the underlying compute service session.
void
HttpComputeService::Shutdown()
{
    m_Impl->m_ComputeService.Shutdown();
}
+
// Returns a snapshot of the current action counts from the underlying
// compute service session.
ComputeServiceSession::ActionCounts
HttpComputeService::GetActionCounts()
{
    return m_Impl->m_ComputeService.GetActionCounts();
}
+
// URI prefix under which this handler is mounted by the HTTP server.
const char*
HttpComputeService::BaseUri() const
{
    return "/compute/";
}
+
// Entry point for all /compute/ requests; records per-request timing and
// dispatches through the router.
// NOTE(review): when no route matches, only a warning is logged and no
// response is written here — presumably the surrounding server layer answers
// with an error in that case; confirm.
void
HttpComputeService::HandleRequest(HttpServerRequest& Request)
{
    ZEN_TRACE_CPU("HttpComputeService::HandleRequest");
    metrics::OperationTiming::Scope $(m_Impl->m_HttpRequests);

    if (m_Impl->m_Router.HandleRequest(Request) == false)
    {
        ZEN_WARN("No route found for {0}", Request.RelativeUri());
    }
}
+
+void
+HttpComputeService::HandleStatsRequest(HttpServerRequest& Request)
+{
+ CbObjectWriter Cbo;
+ m_Impl->m_ComputeService.EmitStats(Cbo);
+
+ Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+}
+
+//////////////////////////////////////////////////////////////////////////
+
// Serializes a queue description into Cbo: counters, state (plus the boolean
// "cancelled"/"draining" fields derived from it), the completion flag, and —
// when present — the queue's metadata, config, and, for remotely-registered
// queues, the remote access token and announcing client's hostname.
void
HttpComputeService::Impl::WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status)
{
    Cbo << "queue_id"sv << Status.QueueId;
    Cbo << "active_count"sv << Status.ActiveCount;
    Cbo << "completed_count"sv << Status.CompletedCount;
    Cbo << "failed_count"sv << Status.FailedCount;
    Cbo << "abandoned_count"sv << Status.AbandonedCount;
    Cbo << "cancelled_count"sv << Status.CancelledCount;
    Cbo << "state"sv << ToString(Status.State);
    Cbo << "cancelled"sv << (Status.State == ComputeServiceSession::QueueState::Cancelled);
    Cbo << "draining"sv << (Status.State == ComputeServiceSession::QueueState::Draining);
    Cbo << "is_complete"sv << Status.IsComplete;

    if (CbObject Meta = m_ComputeService.GetQueueMetadata(QueueId))
    {
        Cbo << "metadata"sv << Meta;
    }

    if (CbObject Cfg = m_ComputeService.GetQueueConfig(QueueId))
    {
        Cbo << "config"sv << Cfg;
    }

    {
        // Shared (read) lock — we only look up the remote-queue registration.
        RwLock::SharedLockScope $(m_RemoteQueueLock);
        if (auto It = m_RemoteQueuesByQueueId.find(QueueId); It != m_RemoteQueuesByQueueId.end())
        {
            Cbo << "queue_token"sv << It->second->Token.ToString();
            if (!It->second->ClientHostname.empty())
            {
                Cbo << "hostname"sv << It->second->ClientHostname;
            }
        }
    }
}
+
+//////////////////////////////////////////////////////////////////////////
+
+int
+HttpComputeService::Impl::ResolveQueueToken(const Oid& Token)
+{
+ RwLock::SharedLockScope $(m_RemoteQueueLock);
+
+ auto It = m_RemoteQueuesByToken.find(Token);
+
+ if (It != m_RemoteQueuesByToken.end())
+ {
+ return It->second->QueueId;
+ }
+
+ return 0;
+}
+
+int
+HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::string_view Capture)
+{
+ if (OidMatcher(Capture))
+ {
+ // Remote OID token — accessible from any client
+ const Oid Token = Oid::FromHexString(Capture);
+ const int QueueId = ResolveQueueToken(Token);
+
+ if (QueueId == 0)
+ {
+ HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ return QueueId;
+ }
+
+ // Local integer queue ID — restricted to local machine requests
+ if (!HttpReq.IsLocalMachineRequest())
+ {
+ HttpReq.WriteResponse(HttpResponseCode::Forbidden);
+ return 0;
+ }
+
+ return ParseInt<int>(Capture).value_or(0);
+}
+
+HttpComputeService::Impl::IngestStats
+HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package)
+{
+ IngestStats Stats;
+
+ for (const CbAttachment& Attachment : Package.GetAttachments())
+ {
+ ZEN_ASSERT(Attachment.IsCompressedBinary());
+
+ const IoHash DataHash = Attachment.GetHash();
+ CompressedBuffer DataView = Attachment.AsCompressedBinary();
+
+ ZEN_UNUSED(DataHash);
+
+ const uint64_t CompressedSize = DataView.GetCompressedSize();
+
+ Stats.Bytes += CompressedSize;
+ ++Stats.Count;
+
+ const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
+
+ if (InsertResult.New)
+ {
+ Stats.NewBytes += CompressedSize;
+ ++Stats.NewCount;
+ }
+ }
+
+ return Stats;
+}
+
+bool
+HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList)
+{
+ ActionObj.IterateAttachments([&](CbFieldView Field) {
+ const IoHash FileHash = Field.AsHash();
+
+ if (!m_CidStore.ContainsChunk(FileHash))
+ {
+ NeedList.push_back(FileHash);
+ }
+ });
+
+ return NeedList.empty();
+}
+
+void
+HttpComputeService::Impl::HandleWorkersGet(HttpServerRequest& HttpReq)
+{
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("workers"sv);
+ for (const IoHash& WorkerId : m_ComputeService.GetKnownWorkerIds())
+ {
+ Cbo << WorkerId;
+ }
+ Cbo.EndArray();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+}
+
// GET handler returning every known worker together with its full descriptor
// object.
// NOTE(review): GetWorkerDescriptor() is used without a validity check — if a
// worker disappears between GetKnownWorkerIds() and the lookup,
// Descriptor.GetObject() runs on an invalid WorkerDesc. HandleWorkerRequest
// checks the descriptor first; confirm this path is equally safe.
void
HttpComputeService::Impl::HandleWorkersAllGet(HttpServerRequest& HttpReq)
{
    std::vector<IoHash> WorkerIds = m_ComputeService.GetKnownWorkerIds();

    CbObjectWriter Cbo;
    Cbo.BeginArray("workers");

    for (const IoHash& WorkerId : WorkerIds)
    {
        Cbo.BeginObject();
        Cbo << "id" << WorkerId;
        Cbo << "descriptor" << m_ComputeService.GetWorkerDescriptor(WorkerId).Descriptor.GetObject();
        Cbo.EndObject();
    }

    Cbo.EndArray();
    HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
}
+
// Handles GET/POST on a single worker resource.
//
// GET  — returns the worker's descriptor object, or 404 when unknown.
// POST — registers (or re-registers) a worker:
//   * kCbObject payload: a worker spec whose attachments must already exist
//     in the CID store. If any are missing, responds 404 with a "need" array
//     of hashes so the caller can re-POST as a full package.
//   * kCbPackage payload: spec plus inline compressed attachments; the
//     attachments are ingested into the CID store before registration.
// NOTE(review): unhandled verbs/content types fall through without writing a
// response — presumably answered by an outer layer; confirm.
void
HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId)
{
    switch (HttpReq.RequestVerb())
    {
        case HttpVerb::kGet:
            if (WorkerDesc Desc = m_ComputeService.GetWorkerDescriptor(WorkerId))
            {
                return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor.GetObject());
            }
            return HttpReq.WriteResponse(HttpResponseCode::NotFound);

        case HttpVerb::kPost:
        {
            switch (HttpReq.RequestContentType())
            {
                case HttpContentType::kCbObject:
                {
                    CbObject WorkerSpec = HttpReq.ReadPayloadObject();

                    // Gather every attachment hash the spec references.
                    HashKeySet ChunkSet;
                    WorkerSpec.IterateAttachments([&](CbFieldView Field) {
                        const IoHash Hash = Field.AsHash();
                        ChunkSet.AddHashToSet(Hash);
                    });

                    CbPackage WorkerPackage;
                    WorkerPackage.SetObject(WorkerSpec);

                    // After filtering, the set holds only the hashes the
                    // store does not yet have (empty set => all available).
                    m_CidStore.FilterChunks(ChunkSet);

                    if (ChunkSet.IsEmpty())
                    {
                        ZEN_DEBUG("worker {}: all attachments already available", WorkerId);
                        m_ComputeService.RegisterWorker(WorkerPackage);
                        return HttpReq.WriteResponse(HttpResponseCode::NoContent);
                    }

                    // Tell the caller which chunks to upload (404 + "need").
                    CbObjectWriter ResponseWriter;
                    ResponseWriter.BeginArray("need");
                    ChunkSet.IterateHashes([&](const IoHash& Hash) {
                        ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash);
                        ResponseWriter.AddHash(Hash);
                    });
                    ResponseWriter.EndArray();

                    ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize());
                    return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save());
                }
                break;

                case HttpContentType::kCbPackage:
                {
                    CbPackage WorkerSpecPackage = HttpReq.ReadPayloadPackage();
                    CbObject WorkerSpec = WorkerSpecPackage.GetObject();

                    std::span<const CbAttachment> Attachments = WorkerSpecPackage.GetAttachments();

                    int AttachmentCount = 0;
                    int NewAttachmentCount = 0;
                    uint64_t TotalAttachmentBytes = 0;
                    uint64_t TotalNewBytes = 0;

                    // Ingest all inline attachments into the CID store,
                    // tracking totals for the debug summary below.
                    for (const CbAttachment& Attachment : Attachments)
                    {
                        ZEN_ASSERT(Attachment.IsCompressedBinary());

                        const IoHash DataHash = Attachment.GetHash();
                        CompressedBuffer Buffer = Attachment.AsCompressedBinary();

                        // NOTE(review): ZEN_UNUSED looks stale — DataHash is
                        // used by AddChunk below.
                        ZEN_UNUSED(DataHash);
                        TotalAttachmentBytes += Buffer.GetCompressedSize();
                        ++AttachmentCount;

                        const CidStore::InsertResult InsertResult =
                            m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash);

                        if (InsertResult.New)
                        {
                            TotalNewBytes += Buffer.GetCompressedSize();
                            ++NewAttachmentCount;
                        }
                    }

                    ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments",
                              WorkerId,
                              zen::NiceBytes(TotalAttachmentBytes),
                              AttachmentCount,
                              zen::NiceBytes(TotalNewBytes),
                              NewAttachmentCount);

                    m_ComputeService.RegisterWorker(WorkerSpecPackage);
                    return HttpReq.WriteResponse(HttpResponseCode::NoContent);
                }
                break;

                default:
                    break;
            }
        }
        break;

        default:
            break;
    }
}
+
+//////////////////////////////////////////////////////////////////////////
+
// No-op anchor. Presumably called from another translation unit to force the
// linker to keep this object file in the final binary — NOTE(review): intent
// inferred from the name; confirm the caller.
void
httpcomputeservice_forcelink()
{
}
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp
new file mode 100644
index 000000000..6cbe01e04
--- /dev/null
+++ b/src/zencompute/httporchestrator.cpp
@@ -0,0 +1,650 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencompute/httporchestrator.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencompute/orchestratorservice.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/logging.h>
+# include <zencore/string.h>
+# include <zencore/system.h>
+
+namespace zen::compute {
+
+// Worker IDs must be 3-64 characters and can only contain letters, numbers, underscores, and dashes
static bool
IsValidWorkerId(std::string_view Id)
{
    // Accept only 3-64 characters drawn from [A-Za-z0-9_-].
    const size_t Length = Id.size();

    if (Length < 3 || Length > 64)
    {
        return false;
    }

    for (const char Ch : Id)
    {
        const bool IsLower = (Ch >= 'a' && Ch <= 'z');
        const bool IsUpper = (Ch >= 'A' && Ch <= 'Z');
        const bool IsDigit = (Ch >= '0' && Ch <= '9');

        if (!IsLower && !IsUpper && !IsDigit && Ch != '_' && Ch != '-')
        {
            return false;
        }
    }

    return true;
}
+
// Shared announce payload parser used by both the HTTP POST route and the
// WebSocket message handler. Returns the worker ID on success (empty on
// validation failure). The populated WorkerAnnouncement out-parameter holds
// string_view fields that reference the supplied CbObjectView, so the object
// backing that view must outlive the announcement.
+static std::string_view
+ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnnouncement& Ann)
+{
+ Ann.Id = Data["id"].AsString("");
+ Ann.Uri = Data["uri"].AsString("");
+
+ if (!IsValidWorkerId(Ann.Id))
+ {
+ return {};
+ }
+
+ if (!Ann.Uri.starts_with("http://") && !Ann.Uri.starts_with("https://"))
+ {
+ return {};
+ }
+
+ Ann.Hostname = Data["hostname"].AsString("");
+ Ann.Platform = Data["platform"].AsString("");
+ Ann.CpuUsagePercent = Data["cpu_usage"].AsFloat(0.0f);
+ Ann.MemoryTotalBytes = Data["memory_total"].AsUInt64(0);
+ Ann.MemoryUsedBytes = Data["memory_used"].AsUInt64(0);
+ Ann.BytesReceived = Data["bytes_received"].AsUInt64(0);
+ Ann.BytesSent = Data["bytes_sent"].AsUInt64(0);
+ Ann.ActionsPending = Data["actions_pending"].AsInt32(0);
+ Ann.ActionsRunning = Data["actions_running"].AsInt32(0);
+ Ann.ActionsCompleted = Data["actions_completed"].AsInt32(0);
+ Ann.ActiveQueues = Data["active_queues"].AsInt32(0);
+ Ann.Provisioner = Data["provisioner"].AsString("");
+
+ if (auto Metrics = Data["metrics"].AsObjectView())
+ {
+ Ann.Cpus = Metrics["lp_count"].AsInt32(0);
+ if (Ann.Cpus <= 0)
+ {
+ Ann.Cpus = 1;
+ }
+ }
+
+ return Ann.Id;
+}
+
+HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket)
+: m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket))
+, m_Hostname(GetMachineName())
+{
+ m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
+ m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
+
+ // dummy endpoint for websocket clients
+ m_Router.RegisterRoute(
+ "ws",
+ [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "status",
+ [this](HttpRouterRequest& Req) {
+ CbObjectWriter Cbo;
+ Cbo << "hostname" << std::string_view(m_Hostname);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "provision",
+ [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "announce",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObject Data = HttpReq.ReadPayloadObject();
+
+ OrchestratorService::WorkerAnnouncement Ann;
+ std::string_view WorkerId = ParseWorkerAnnouncement(Data, Ann);
+
+ if (WorkerId.empty())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ "Invalid worker announcement: id must be 3-64 alphanumeric/underscore/dash "
+ "characters and uri must start with http:// or https://");
+ }
+
+ m_Service->AnnounceWorker(Ann);
+
+ HttpReq.WriteResponse(HttpResponseCode::OK);
+
+# if ZEN_WITH_WEBSOCKETS
+ // Notify push thread that state may have changed
+ m_PushEvent.Set();
+# endif
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "agents",
+ [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "history",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ auto Params = HttpReq.GetQueryParams();
+
+ int Limit = 100;
+ auto LimitStr = Params.GetValue("limit");
+ if (!LimitStr.empty())
+ {
+ Limit = std::atoi(std::string(LimitStr).c_str());
+ }
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetProvisioningHistory(Limit));
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "timeline/{workerid}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ std::string_view WorkerId = Req.GetCapture(1);
+ auto Params = HttpReq.GetQueryParams();
+
+ auto FromStr = Params.GetValue("from");
+ auto ToStr = Params.GetValue("to");
+ auto LimitStr = Params.GetValue("limit");
+
+ std::optional<DateTime> From;
+ std::optional<DateTime> To;
+
+ if (!FromStr.empty())
+ {
+ auto Val = zen::ParseInt<uint64_t>(FromStr);
+ if (!Val)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ From = DateTime(*Val);
+ }
+
+ if (!ToStr.empty())
+ {
+ auto Val = zen::ParseInt<uint64_t>(ToStr);
+ if (!Val)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ To = DateTime(*Val);
+ }
+
+ int Limit = !LimitStr.empty() ? zen::ParseInt<int>(LimitStr).value_or(0) : 0;
+
+ CbObject Result = m_Service->GetWorkerTimeline(WorkerId, From, To, Limit);
+
+ if (!Result)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result));
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "timeline",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ auto Params = HttpReq.GetQueryParams();
+
+ auto FromStr = Params.GetValue("from");
+ auto ToStr = Params.GetValue("to");
+
+ DateTime From = DateTime(0);
+ DateTime To = DateTime::Now();
+
+ if (!FromStr.empty())
+ {
+ auto Val = zen::ParseInt<uint64_t>(FromStr);
+ if (!Val)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ From = DateTime(*Val);
+ }
+
+ if (!ToStr.empty())
+ {
+ auto Val = zen::ParseInt<uint64_t>(ToStr);
+ if (!Val)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ To = DateTime(*Val);
+ }
+
+ CbObject Result = m_Service->GetAllTimelines(From, To);
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result));
+ },
+ HttpVerb::kGet);
+
+ // Client tracking endpoints
+
+ m_Router.RegisterRoute(
+ "clients",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObject Data = HttpReq.ReadPayloadObject();
+
+ OrchestratorService::ClientAnnouncement Ann;
+ Ann.SessionId = Data["session_id"].AsObjectId(Oid::Zero);
+ Ann.Hostname = Data["hostname"].AsString("");
+ Ann.Address = HttpReq.GetRemoteAddress();
+
+ auto MetadataView = Data["metadata"].AsObjectView();
+ if (MetadataView)
+ {
+ Ann.Metadata = CbObject::Clone(MetadataView);
+ }
+
+ std::string ClientId = m_Service->AnnounceClient(Ann);
+
+ CbObjectWriter ResponseObj;
+ ResponseObj << "id" << std::string_view(ClientId);
+ HttpReq.WriteResponse(HttpResponseCode::OK, ResponseObj.Save());
+
+# if ZEN_WITH_WEBSOCKETS
+ m_PushEvent.Set();
+# endif
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "clients/{clientid}/update",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view ClientId = Req.GetCapture(1);
+
+ CbObject MetadataObj;
+ CbObject Data = HttpReq.ReadPayloadObject();
+ if (Data)
+ {
+ auto MetadataView = Data["metadata"].AsObjectView();
+ if (MetadataView)
+ {
+ MetadataObj = CbObject::Clone(MetadataView);
+ }
+ }
+
+ if (m_Service->UpdateClient(ClientId, std::move(MetadataObj)))
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK);
+ }
+ else
+ {
+ HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "clients/{clientid}/complete",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view ClientId = Req.GetCapture(1);
+
+ if (m_Service->CompleteClient(ClientId))
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK);
+ }
+ else
+ {
+ HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+# if ZEN_WITH_WEBSOCKETS
+ m_PushEvent.Set();
+# endif
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "clients",
+ [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetClientList()); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "clients/history",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ auto Params = HttpReq.GetQueryParams();
+
+ int Limit = 100;
+ auto LimitStr = Params.GetValue("limit");
+ if (!LimitStr.empty())
+ {
+ Limit = std::atoi(std::string(LimitStr).c_str());
+ }
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetClientHistory(Limit));
+ },
+ HttpVerb::kGet);
+
+# if ZEN_WITH_WEBSOCKETS
+
+ // Start the WebSocket push thread
+ m_PushEnabled.store(true);
+ m_PushThread = std::thread([this] { PushThreadFunction(); });
+# endif
+}
+
// Delegates to Shutdown(); Shutdown() is guarded by an atomic exchange, so a
// prior explicit call is harmless.
HttpOrchestratorService::~HttpOrchestratorService()
{
    Shutdown();
}
+
// Stops the WebSocket push machinery. Idempotent: only the first caller (the
// one that flips m_PushEnabled) performs teardown, so both an explicit
// Shutdown() and the destructor's call are safe.
void
HttpOrchestratorService::Shutdown()
{
# if ZEN_WITH_WEBSOCKETS
    if (!m_PushEnabled.exchange(false))
    {
        return;
    }

    // Stop the push thread first, before touching connections. This ensures
    // the push thread is no longer reading m_WsConnections or calling into
    // m_Service when we start tearing things down.
    m_PushEvent.Set();
    if (m_PushThread.joinable())
    {
        m_PushThread.join();
    }

    // Clean up worker WebSocket connections — collect IDs under lock, then
    // notify the service outside the lock to avoid lock-order inversions.
    std::vector<std::string> WorkerIds;
    m_WorkerWsLock.WithExclusiveLock([&] {
        WorkerIds.reserve(m_WorkerWsMap.size());
        for (const auto& [Conn, Id] : m_WorkerWsMap)
        {
            WorkerIds.push_back(Id);
        }
        m_WorkerWsMap.clear();
    });
    for (const auto& Id : WorkerIds)
    {
        m_Service->SetWorkerWebSocketConnected(Id, false);
    }

    // Now that the push thread is gone, release all dashboard connections.
    m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); });
# endif
}
+
// URI prefix under which this handler is mounted by the HTTP server.
const char*
HttpOrchestratorService::BaseUri() const
{
    return "/orch/";
}
+
// Entry point for all /orch/ requests; dispatches through the router.
// NOTE(review): when no route matches, only a warning is logged and no
// response is written here — presumably the surrounding server layer answers
// with an error in that case; confirm.
void
HttpOrchestratorService::HandleRequest(HttpServerRequest& Request)
{
    if (m_Router.HandleRequest(Request) == false)
    {
        ZEN_WARN("No route found for {0}", Request.RelativeUri());
    }
}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// IWebSocketHandler
+//
+
+# if ZEN_WITH_WEBSOCKETS
// Registers a newly-opened dashboard WebSocket connection and wakes the push
// thread so the client receives the current state without waiting for the
// next poll tick. Connections arriving during/after shutdown are ignored.
void
HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
{
    if (!m_PushEnabled.load())
    {
        return;
    }

    ZEN_INFO("WebSocket client connected");

    m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });

    // Wake push thread to send initial state immediately
    m_PushEvent.Set();
}
+
// Handles an incoming WebSocket message. The only messages processed here
// are binary worker announcements (and only when the worker-WebSocket
// feature is enabled). The first valid announcement on a connection records
// it in m_WorkerWsMap so OnWebSocketClose can mark the worker disconnected.
void
HttpOrchestratorService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg)
{
    // Only handle binary messages from workers when the feature is enabled.
    if (!m_Service->IsWorkerWebSocketEnabled() || Msg.Opcode != WebSocketOpcode::kBinary)
    {
        return;
    }

    std::string WorkerId = HandleWorkerWebSocketMessage(Msg);
    if (WorkerId.empty())
    {
        return;
    }

    // Check if this is a new worker WebSocket connection
    bool IsNewWorkerWs = false;
    m_WorkerWsLock.WithExclusiveLock([&] {
        auto It = m_WorkerWsMap.find(&Conn);
        if (It == m_WorkerWsMap.end())
        {
            m_WorkerWsMap[&Conn] = WorkerId;
            IsNewWorkerWs = true;
        }
    });

    // Notify the service outside the worker lock.
    if (IsNewWorkerWs)
    {
        m_Service->SetWorkerWebSocketConnected(WorkerId, true);
    }

    m_PushEvent.Set();
}
+
// Parses a binary worker announcement received over a WebSocket and forwards
// it to the service. Returns the worker ID on success, or an empty string
// when the payload is not a valid CbObject or fails announcement validation.
// The announcement's string_view fields point into the local view, which is
// why the ID is returned as an owned std::string.
std::string
HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Msg)
{
    // Workers send CbObject in native binary format over the WebSocket to
    // avoid the lossy CbObject↔JSON round-trip.
    CbObject Data = CbObject::MakeView(Msg.Payload.GetData());
    if (!Data)
    {
        ZEN_WARN("worker WebSocket message is not a valid CbObject");
        return {};
    }

    OrchestratorService::WorkerAnnouncement Ann;
    std::string_view WorkerId = ParseWorkerAnnouncement(Data, Ann);
    if (WorkerId.empty())
    {
        ZEN_WARN("invalid worker announcement via WebSocket");
        return {};
    }

    m_Service->AnnounceWorker(Ann);
    return std::string(WorkerId);
}
+
// Removes a closed connection from tracking. Worker connections additionally
// notify the service (outside the worker lock) and wake the push thread so
// dashboards see the disconnect promptly. Dashboard-connection cleanup is
// skipped once shutdown has begun — Shutdown() clears the list itself.
void
HttpOrchestratorService::OnWebSocketClose(WebSocketConnection& Conn,
                                          [[maybe_unused]] uint16_t Code,
                                          [[maybe_unused]] std::string_view Reason)
{
    ZEN_INFO("WebSocket client disconnected (code {})", Code);

    // Check if this was a worker WebSocket connection; collect the ID under
    // the worker lock, then notify the service outside the lock.
    std::string DisconnectedWorkerId;
    m_WorkerWsLock.WithExclusiveLock([&] {
        auto It = m_WorkerWsMap.find(&Conn);
        if (It != m_WorkerWsMap.end())
        {
            DisconnectedWorkerId = std::move(It->second);
            m_WorkerWsMap.erase(It);
        }
    });

    if (!DisconnectedWorkerId.empty())
    {
        m_Service->SetWorkerWebSocketConnected(DisconnectedWorkerId, false);
        m_PushEvent.Set();
    }

    if (!m_PushEnabled.load())
    {
        return;
    }

    // Remove from dashboard connections
    m_WsConnectionsLock.WithExclusiveLock([&] {
        auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) {
            return C.Get() == &Conn;
        });
        m_WsConnections.erase(It, m_WsConnections.end());
    });
}
+# endif
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Push thread
+//
+
+# if ZEN_WITH_WEBSOCKETS
+void
+HttpOrchestratorService::PushThreadFunction()
+{
+ SetCurrentThreadName("orch_ws_push");
+
+ while (m_PushEnabled.load())
+ {
+ m_PushEvent.Wait(2000);
+ m_PushEvent.Reset();
+
+ if (!m_PushEnabled.load())
+ {
+ break;
+ }
+
+ // Snapshot current connections
+ std::vector<Ref<WebSocketConnection>> Connections;
+ m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; });
+
+ if (Connections.empty())
+ {
+ continue;
+ }
+
+ // Build combined JSON with worker list, provisioning history, clients, and client history
+ CbObject WorkerList = m_Service->GetWorkerList();
+ CbObject History = m_Service->GetProvisioningHistory(50);
+ CbObject ClientList = m_Service->GetClientList();
+ CbObject ClientHistory = m_Service->GetClientHistory(50);
+
+ ExtendableStringBuilder<4096> JsonBuilder;
+ JsonBuilder.Append("{");
+ JsonBuilder.Append(fmt::format("\"hostname\":\"{}\",", m_Hostname));
+
+ // Emit workers array from worker list
+ ExtendableStringBuilder<2048> WorkerJson;
+ WorkerList.ToJson(WorkerJson);
+ std::string_view WorkerJsonView = WorkerJson.ToView();
+ // Strip outer braces: {"workers":[...]} -> "workers":[...]
+ if (WorkerJsonView.size() >= 2)
+ {
+ JsonBuilder.Append(WorkerJsonView.substr(1, WorkerJsonView.size() - 2));
+ }
+
+ JsonBuilder.Append(",");
+
+ // Emit events array from history
+ ExtendableStringBuilder<2048> HistoryJson;
+ History.ToJson(HistoryJson);
+ std::string_view HistoryJsonView = HistoryJson.ToView();
+ if (HistoryJsonView.size() >= 2)
+ {
+ JsonBuilder.Append(HistoryJsonView.substr(1, HistoryJsonView.size() - 2));
+ }
+
+ JsonBuilder.Append(",");
+
+ // Emit clients array from client list
+ ExtendableStringBuilder<2048> ClientJson;
+ ClientList.ToJson(ClientJson);
+ std::string_view ClientJsonView = ClientJson.ToView();
+ if (ClientJsonView.size() >= 2)
+ {
+ JsonBuilder.Append(ClientJsonView.substr(1, ClientJsonView.size() - 2));
+ }
+
+ JsonBuilder.Append(",");
+
+ // Emit client_events array from client history
+ ExtendableStringBuilder<2048> ClientHistoryJson;
+ ClientHistory.ToJson(ClientHistoryJson);
+ std::string_view ClientHistoryJsonView = ClientHistoryJson.ToView();
+ if (ClientHistoryJsonView.size() >= 2)
+ {
+ JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2));
+ }
+
+ JsonBuilder.Append("}");
+ std::string_view Json = JsonBuilder.ToView();
+
+ // Broadcast to all connected clients, prune closed ones
+ bool HadClosedConnections = false;
+
+ for (auto& Conn : Connections)
+ {
+ if (Conn->IsOpen())
+ {
+ Conn->SendText(Json);
+ }
+ else
+ {
+ HadClosedConnections = true;
+ }
+ }
+
+ if (HadClosedConnections)
+ {
+ m_WsConnectionsLock.WithExclusiveLock([&] {
+ auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref<WebSocketConnection>& C) {
+ return !C->IsOpen();
+ });
+ m_WsConnections.erase(It, m_WsConnections.end());
+ });
+ }
+ }
+}
+# endif
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h
new file mode 100644
index 000000000..a5bc5a34d
--- /dev/null
+++ b/src/zencompute/include/zencompute/cloudmetadata.h
@@ -0,0 +1,151 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/logging.h>
+#include <zencore/thread.h>
+
+#include <atomic>
+#include <filesystem>
+#include <string>
+#include <thread>
+
+namespace zen::compute {
+
/** Cloud provider detected via the Instance Metadata Service (see CloudMetadata). */
enum class CloudProvider
{
    None, // bare metal / non-cloud, or detection has not succeeded
    AWS,
    Azure,
    GCP
};

/** Returns a human-readable name for the provider (suitable for logs/serialization). */
std::string_view ToString(CloudProvider Provider);
+
+/** Snapshot of detected cloud instance properties. */
/** Snapshot of detected cloud instance properties. */
struct CloudInstanceInfo
{
    CloudProvider Provider = CloudProvider::None; // None when not running on a detected cloud VM
    std::string InstanceId;                       // provider-assigned instance identifier (empty if undetected)
    std::string AvailabilityZone;                 // placement zone string, in the provider's own format
    bool IsSpot = false;                          // true for spot/preemptible capacity
    bool IsAutoscaling = false;                   // true when the instance is managed by an autoscaling group / scale set
};
+
+/**
+ * Detects whether the process is running on a cloud VM (AWS, Azure, or GCP)
+ * and monitors for impending termination signals.
+ *
+ * Detection works by querying the Instance Metadata Service (IMDS) at the
+ * well-known link-local address 169.254.169.254, which is only routable from
+ * within a cloud VM. Each provider is probed in sequence (AWS -> Azure -> GCP);
+ * the first successful response wins.
+ *
+ * To avoid a ~200ms connect timeout penalty on every startup when running on
+ * bare-metal or non-cloud machines, failed probes write sentinel files
+ * (e.g. ".isNotAWS") to DataDir. Subsequent startups skip providers that have
+ * a sentinel present. Delete the sentinel files to force re-detection.
+ *
+ * When a provider is detected, a background thread polls for termination
+ * signals every 5 seconds (spot interruption, autoscaling lifecycle changes,
+ * scheduled maintenance). The termination state is exposed as an atomic bool
+ * so the compute server can include it in coordinator announcements and react
+ * to imminent shutdown.
+ *
+ * Thread safety: GetInstanceInfo() and GetTerminationReason() acquire a
+ * shared RwLock; the background monitor thread acquires the exclusive lock
+ * only when writing the termination reason (a one-time transition). The
+ * termination-pending flag itself is a lock-free atomic.
+ *
+ * Usage:
+ * auto Cloud = std::make_unique<CloudMetadata>(DataDir / "cloud");
+ * if (Cloud->IsTerminationPending()) { ... }
+ * Cloud->Describe(AnnounceBody); // writes "cloud" sub-object into CB
+ */
class CloudMetadata
{
public:
    /** Synchronously probes cloud providers and starts the termination monitor
     * if a provider is detected. Creates DataDir if it does not exist.
     */
    explicit CloudMetadata(std::filesystem::path DataDir);

    /** Synchronously probes cloud providers at the given IMDS endpoint.
     * Intended for testing — allows redirecting all IMDS queries to a local
     * mock HTTP server instead of the real 169.254.169.254 endpoint.
     */
    CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint);

    /** Stops the termination monitor thread and joins it. */
    ~CloudMetadata();

    // Non-copyable: owns a background thread and on-disk sentinel state.
    CloudMetadata(const CloudMetadata&) = delete;
    CloudMetadata& operator=(const CloudMetadata&) = delete;

    /** Provider detected at construction time (None on bare metal). */
    CloudProvider GetProvider() const;
    /** Returns a copy of the detected instance properties (thread-safe snapshot). */
    CloudInstanceInfo GetInstanceInfo() const;
    /** True once the monitor has observed a termination signal (lock-free read). */
    bool IsTerminationPending() const;
    /** Human-readable reason recorded when termination was detected; empty otherwise. */
    std::string GetTerminationReason() const;

    /** Writes a "cloud" sub-object into the compact binary writer if a provider
     * was detected. No-op when running on bare metal.
     */
    void Describe(CbWriter& Writer) const;

    /** Executes a single termination-poll cycle for the detected provider.
     * Public so tests can drive poll cycles synchronously without relying on
     * the background thread's 5-second timer.
     */
    void PollTermination();

    /** Removes the negative-cache sentinel files (.isNotAWS, .isNotAzure,
     * .isNotGCP) from DataDir so subsequent detection probes are not skipped.
     * Primarily intended for tests that need to reset state between sub-cases.
     */
    void ClearSentinelFiles();

private:
    /** Tries each provider in order, stops on first successful detection. */
    void DetectProvider();
    bool TryDetectAWS();
    bool TryDetectAzure();
    bool TryDetectGCP();

    // Negative-cache helpers: a sentinel file marks "this provider was probed
    // and is not present", letting later startups skip the probe entirely.
    void WriteSentinelFile(const std::filesystem::path& Path);
    bool HasSentinelFile(const std::filesystem::path& Path) const;

    void StartTerminationMonitor();
    void TerminationMonitorThread();
    void PollAWSTermination();
    void PollAzureTermination();
    void PollGCPTermination();

    LoggerRef Log() { return m_Log; }

    LoggerRef m_Log;
    std::filesystem::path m_DataDir;     // holds sentinel files (and is created on demand)
    std::string m_ImdsEndpoint;          // overridable for tests; real IMDS otherwise

    mutable RwLock m_InfoLock;           // guards m_Info (shared for readers)
    CloudInstanceInfo m_Info;

    std::atomic<bool> m_TerminationPending{false};

    mutable RwLock m_ReasonLock;         // guards m_TerminationReason
    std::string m_TerminationReason;

    // IMDSv2 session token, acquired during AWS detection and reused for
    // subsequent termination polling. Has a 300s TTL on the AWS side; if it
    // expires mid-run the poll requests will get 401s which we treat as
    // non-terminal (the monitor simply retries next cycle).
    std::string m_AwsToken;

    std::thread m_MonitorThread;             // termination monitor; joined in dtor
    std::atomic<bool> m_MonitorEnabled{true};
    Event m_MonitorEvent;                    // signalled to wake/stop the monitor
};
+
+void cloudmetadata_forcelink(); // internal
+
+} // namespace zen::compute
diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h
new file mode 100644
index 000000000..65ec5f9ee
--- /dev/null
+++ b/src/zencompute/include/zencompute/computeservice.h
@@ -0,0 +1,262 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencompute/zencompute.h>
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/iohash.h>
+# include <zenstore/zenstore.h>
+# include <zenhttp/httpcommon.h>
+
+# include <filesystem>
+
+namespace zen {
+class ChunkResolver;
+class CbObjectWriter;
+} // namespace zen
+
+namespace zen::compute {
+
+class ActionRecorder;
+class ComputeServiceSession;
+class IActionResultHandler;
+class LocalProcessRunner;
+class RemoteHttpRunner;
+struct RunnerAction;
+struct SubmitResult;
+
/** A registered worker: its descriptor package plus the identifying hash.
 * WorkerId of IoHash::Zero denotes "no worker" (see operator bool).
 */
struct WorkerDesc
{
    CbPackage Descriptor;          // compact-binary package describing the worker
    IoHash WorkerId{IoHash::Zero}; // IoHash::Zero == invalid/unset

    // True when this describes a valid (non-zero-id) worker.
    inline operator bool() const { return WorkerId != IoHash::Zero; }
};
+
+/**
+ * Lambda style compute function service
+ *
+ * The responsibility of this class is to accept function execution requests, and
+ * schedule them using one or more FunctionRunner instances. It will basically always
+ * accept requests, queueing them if necessary, and then hand them off to runners
+ * as they become available.
+ *
+ * This is typically fronted by an API service that handles communication with clients.
+ */
+class ComputeServiceSession final
+{
+public:
+ /**
+ * Session lifecycle state machine.
+ *
+ * Forward transitions: Created -> Ready -> Draining -> Paused -> Sunset
+ * Backward transitions: Draining -> Ready, Paused -> Ready
+ * Automatic transition: Draining -> Paused (when pending + running reaches 0)
+ * Jump transitions: any non-terminal -> Abandoned, any non-terminal -> Sunset
+ * Terminal states: Abandoned (only Sunset out), Sunset (no transitions out)
+ *
+ * | State | Accept new actions | Schedule pending | Finish running |
+ * |-----------|-------------------|-----------------|----------------|
+ * | Created | No | No | N/A |
+ * | Ready | Yes | Yes | Yes |
+ * | Draining | No | Yes | Yes |
+ * | Paused | No | No | No |
+ * | Abandoned | No | No | No (all abandoned) |
+ * | Sunset | No | No | No |
+ */
+ enum class SessionState
+ {
+ Created, // Initial state before WaitUntilReady completes
+ Ready, // Normal operating state; accepts and schedules work
+ Draining, // Stops accepting new work; finishes existing; auto-transitions to Paused when empty
+ Paused, // Idle; no work accepted or scheduled; can resume to Ready
+ Abandoned, // Spot termination grace period; all actions abandoned; only Sunset out
+ Sunset // Terminal; triggers full shutdown
+ };
+
+ ComputeServiceSession(ChunkResolver& InChunkResolver);
+ ~ComputeServiceSession();
+
+ void WaitUntilReady();
+ void Shutdown();
+ bool IsHealthy();
+
+ SessionState GetSessionState() const;
+
+ // Request a state transition. Returns false if the transition is invalid.
+ // Sunset can be reached from any non-Sunset state.
+ bool RequestStateTransition(SessionState NewState);
+
+ // Orchestration
+
+ void SetOrchestratorEndpoint(std::string_view Endpoint);
+ void SetOrchestratorBasePath(std::filesystem::path BasePath);
+
+ // Worker registration and discovery
+
+ void RegisterWorker(CbPackage Worker);
+ [[nodiscard]] WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId);
+ [[nodiscard]] std::vector<IoHash> GetKnownWorkerIds();
+
+ // Action runners
+
+ void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0);
+ void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName);
+
+ // Action submission
+
+ struct EnqueueResult
+ {
+ int Lsn;
+ CbObject ResponseMessage;
+
+ inline operator bool() const { return Lsn != 0; }
+ };
+
+ [[nodiscard]] EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int Priority);
+ [[nodiscard]] EnqueueResult EnqueueAction(CbObject ActionObject, int Priority);
+
+ // Queue management
+ //
+ // Queues group actions submitted by a single client session. They allow
+ // cancelling or polling completion of all actions in the group.
+
+ struct CreateQueueResult
+ {
+ int QueueId = 0; // 0 if creation failed
+ };
+
+ enum class QueueState
+ {
+ Active,
+ Draining,
+ Cancelled,
+ };
+
+ struct QueueStatus
+ {
+ bool IsValid = false;
+ int QueueId = 0;
+ int ActiveCount = 0; // pending + running (not yet completed)
+ int CompletedCount = 0; // successfully completed
+ int FailedCount = 0; // failed
+ int AbandonedCount = 0; // abandoned
+ int CancelledCount = 0; // cancelled
+ QueueState State = QueueState::Active;
+ bool IsComplete = false; // ActiveCount == 0
+ };
+
+ [[nodiscard]] CreateQueueResult CreateQueue(std::string_view Tag = {}, CbObject Metadata = {}, CbObject Config = {});
+ [[nodiscard]] std::vector<int> GetQueueIds();
+ [[nodiscard]] QueueStatus GetQueueStatus(int QueueId);
+ [[nodiscard]] CbObject GetQueueMetadata(int QueueId);
+ [[nodiscard]] CbObject GetQueueConfig(int QueueId);
+ void CancelQueue(int QueueId);
+ void DrainQueue(int QueueId);
+ void DeleteQueue(int QueueId);
+ void GetQueueCompleted(int QueueId, CbWriter& Cbo);
+
+ // Queue-scoped action submission. Actions submitted via these methods are
+ // tracked under the given queue in addition to the global LSN-based tracking.
+
+ [[nodiscard]] EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority);
+ [[nodiscard]] EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority);
+
+ // Completed action tracking
+
+ [[nodiscard]] HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage);
+ [[nodiscard]] HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage);
+ void RetireActionResult(int ActionLsn);
+
+ // Action rescheduling
+
+ struct RescheduleResult
+ {
+ bool Success = false;
+ std::string Error;
+ int RetryCount = 0;
+ };
+
+ [[nodiscard]] RescheduleResult RescheduleAction(int ActionLsn);
+
+ void GetCompleted(CbWriter&);
+
+ // Running action tracking
+
+ struct RunningActionInfo
+ {
+ int Lsn;
+ int QueueId;
+ IoHash ActionId;
+ float CpuUsagePercent; // -1.0 if not yet sampled
+ float CpuSeconds; // 0.0 if not yet sampled
+ };
+
+ [[nodiscard]] std::vector<RunningActionInfo> GetRunningActions();
+
+ // Action history tracking (note that this is separate from completed action tracking, and
+ // will include actions which have been retired and no longer have their results available)
+
+ struct ActionHistoryEntry
+ {
+ int Lsn;
+ int QueueId = 0;
+ IoHash ActionId;
+ IoHash WorkerId;
+ CbObject ActionDescriptor;
+ std::string ExecutionLocation;
+ bool Succeeded;
+ float CpuSeconds = 0.0f; // total CPU time at completion; 0.0 if not sampled
+ int RetryCount = 0; // number of times this action was rescheduled
+ // sized to match RunnerAction::State::_Count but we can't use the enum here
+ // for dependency reasons, so just use a fixed size array and static assert in
+ // the implementation file
+ uint64_t Timestamps[8] = {};
+ };
+
+ [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100);
+ [[nodiscard]] std::vector<ActionHistoryEntry> GetQueueHistory(int QueueId, int Limit = 100);
+
+ // Stats reporting
+
+ struct ActionCounts
+ {
+ int Pending = 0;
+ int Running = 0;
+ int Completed = 0;
+ int ActiveQueues = 0;
+ };
+
+ [[nodiscard]] ActionCounts GetActionCounts();
+
+ void EmitStats(CbObjectWriter& Cbo);
+
+ // Recording
+
+ void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath);
+ void StopRecording();
+
+private:
+ void PostUpdate(RunnerAction* Action);
+
+ friend class FunctionRunner;
+ friend struct RunnerAction;
+
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+void computeservice_forcelink();
+
+} // namespace zen::compute
+
+namespace zen {
+const char* ToString(compute::ComputeServiceSession::SessionState State);
+const char* ToString(compute::ComputeServiceSession::QueueState State);
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h
new file mode 100644
index 000000000..ee1cd2614
--- /dev/null
+++ b/src/zencompute/include/zencompute/httpcomputeservice.h
@@ -0,0 +1,54 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencompute/zencompute.h>
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include "zencompute/computeservice.h"
+
+# include <zenhttp/httpserver.h>
+
+# include <filesystem>
+# include <memory>
+
+namespace zen {
+class CidStore;
+}
+
+namespace zen::compute {
+
+/**
+ * HTTP interface for compute service
+ */
/**
 * HTTP interface for compute service
 *
 * Routes compute requests to a ComputeServiceSession (hidden behind Impl)
 * and participates in the shared stats endpoint via IHttpStatsProvider.
 */
class HttpComputeService : public HttpService, public IHttpStatsProvider
{
public:
    // MaxConcurrentActions == 0 presumably means "use a default/auto limit" —
    // TODO confirm against the implementation.
    HttpComputeService(CidStore& InCidStore,
                       IHttpStatsService& StatsService,
                       const std::filesystem::path& BaseDir,
                       int32_t MaxConcurrentActions = 0);
    ~HttpComputeService();

    void Shutdown();

    /** Current pending/running/completed counters from the underlying session. */
    [[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts();

    // HttpService
    const char* BaseUri() const override;
    void HandleRequest(HttpServerRequest& Request) override;

    // IHttpStatsProvider

    void HandleStatsRequest(HttpServerRequest& Request) override;

private:
    struct Impl;
    std::unique_ptr<Impl> m_Impl; // pimpl: implementation details stay out of the header
};
+
+void httpcomputeservice_forcelink();
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h
new file mode 100644
index 000000000..da5c5dfc3
--- /dev/null
+++ b/src/zencompute/include/zencompute/httporchestrator.h
@@ -0,0 +1,101 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencompute/zencompute.h>
+
+#include <zencore/logging.h>
+#include <zencore/thread.h>
+#include <zenhttp/httpserver.h>
+#include <zenhttp/websocket.h>
+
+#include <atomic>
+#include <filesystem>
+#include <memory>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <vector>
+
+#define ZEN_WITH_WEBSOCKETS 1
+
+namespace zen::compute {
+
+class OrchestratorService;
+
+// Experimental helper, to see if we can get rid of some error-prone
+// boilerplate when declaring loggers as class members.
+
+class LoggerHelper
+{
+public:
+ LoggerHelper(std::string_view Logger) : m_Log(logging::Get(Logger)) {}
+
+ LoggerRef operator()() { return m_Log; }
+
+private:
+ LoggerRef m_Log;
+};
+
+/**
+ * Orchestrator HTTP service with WebSocket push support
+ *
+ * Normal HTTP requests are routed through the HttpRequestRouter as before.
+ * WebSocket clients connecting to /orch/ws receive periodic state broadcasts
+ * from a dedicated push thread, eliminating the need for polling.
+ */
+
class HttpOrchestratorService : public HttpService
#if ZEN_WITH_WEBSOCKETS
,
  public IWebSocketHandler
#endif
{
public:
    explicit HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false);
    ~HttpOrchestratorService();

    // Non-copyable: owns the push thread and live WebSocket connections.
    HttpOrchestratorService(const HttpOrchestratorService&) = delete;
    HttpOrchestratorService& operator=(const HttpOrchestratorService&) = delete;

    /**
     * Gracefully shut down the WebSocket push thread and release connections.
     * Must be called while the ASIO io_context is still alive. The destructor
     * also calls this, so it is safe (but not ideal) to omit the explicit call.
     */
    void Shutdown();

    // HttpService
    virtual const char* BaseUri() const override;
    virtual void HandleRequest(HttpServerRequest& Request) override;

    // IWebSocketHandler
#if ZEN_WITH_WEBSOCKETS
    void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
    void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
    void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
#endif

private:
    HttpRequestRouter m_Router;                      // dispatches plain HTTP requests
    LoggerHelper Log{"orch"};
    std::unique_ptr<OrchestratorService> m_Service;  // backing orchestrator state
    std::string m_Hostname;

    // WebSocket push

#if ZEN_WITH_WEBSOCKETS
    RwLock m_WsConnectionsLock;                       // guards m_WsConnections
    std::vector<Ref<WebSocketConnection>> m_WsConnections;
    std::thread m_PushThread;                         // periodic state broadcaster
    std::atomic<bool> m_PushEnabled{false};
    Event m_PushEvent;                                // wakes/stops the push thread
    void PushThreadFunction();

    // Worker WebSocket connections (worker→orchestrator persistent links)
    RwLock m_WorkerWsLock;                            // guards m_WorkerWsMap
    std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr → worker ID
    std::string HandleWorkerWebSocketMessage(const WebSocketMessage& Msg);
#endif
};
+
+} // namespace zen::compute
diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h
new file mode 100644
index 000000000..521722e63
--- /dev/null
+++ b/src/zencompute/include/zencompute/mockimds.h
@@ -0,0 +1,102 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencompute/cloudmetadata.h>
+#include <zenhttp/httpserver.h>
+
+#include <string>
+
+#if ZEN_WITH_TESTS
+
+namespace zen::compute {
+
+/**
+ * Mock IMDS (Instance Metadata Service) for testing CloudMetadata.
+ *
+ * Implements an HttpService that responds to the same URL paths as the real
+ * cloud provider metadata endpoints (AWS IMDSv2, Azure IMDS, GCP metadata).
+ * Tests configure which provider is "active" and set the desired response
+ * values, then pass the mock server's address as the ImdsEndpoint to the
+ * CloudMetadata constructor.
+ *
+ * When a request arrives for a provider that is not the ActiveProvider, the
+ * mock returns 404, causing CloudMetadata to write a sentinel file and move
+ * on to the next provider — exactly like a failed probe on bare metal.
+ *
+ * All config fields are public and can be mutated between poll cycles to
+ * simulate state changes (e.g. a spot interruption appearing mid-run).
+ *
+ * Usage:
+ * MockImdsService Mock;
+ * Mock.ActiveProvider = CloudProvider::AWS;
+ * Mock.Aws.InstanceId = "i-test";
+ * // ... stand up ASIO server, register Mock, create CloudMetadata with endpoint
+ */
class MockImdsService : public HttpService
{
public:
    /** AWS IMDSv2 response configuration. */
    struct AwsConfig
    {
        std::string Token = "mock-aws-token-v2";
        std::string InstanceId = "i-0123456789abcdef0";
        std::string AvailabilityZone = "us-east-1a";
        std::string LifeCycle = "on-demand"; // "spot" or "on-demand"

        // Empty string → endpoint returns 404 (instance not in an ASG).
        // Non-empty → returned as the response body. "InService" means healthy;
        // anything else (e.g. "Terminated:Wait") triggers termination detection.
        std::string AutoscalingState;

        // Empty string → endpoint returns 404 (no spot interruption).
        // Non-empty → returned as the response body, signalling a spot reclaim.
        std::string SpotAction;
    };

    /** Azure IMDS response configuration. */
    struct AzureConfig
    {
        std::string VmId = "vm-12345678-1234-1234-1234-123456789abc";
        std::string Location = "eastus";
        std::string Priority = "Regular"; // "Spot" or "Regular"

        // Empty → instance is not in a VM Scale Set (no autoscaling).
        std::string VmScaleSetName;

        // Empty → no scheduled events. Set to "Preempt", "Terminate", or
        // "Reboot" to simulate a termination-class event.
        std::string ScheduledEventType;
        std::string ScheduledEventStatus = "Scheduled";
    };

    /** GCP metadata response configuration. */
    struct GcpConfig
    {
        std::string InstanceId = "1234567890123456789";
        std::string Zone = "projects/123456/zones/us-central1-a";
        std::string Preemptible = "FALSE";     // "TRUE" or "FALSE"
        std::string MaintenanceEvent = "NONE"; // "NONE" or event description
    };

    /** Which provider's endpoints respond successfully.
     * Requests targeting other providers receive 404.
     */
    CloudProvider ActiveProvider = CloudProvider::None;

    // Mutable between poll cycles to simulate state changes mid-run.
    AwsConfig Aws;
    AzureConfig Azure;
    GcpConfig Gcp;

    // HttpService
    const char* BaseUri() const override;
    void HandleRequest(HttpServerRequest& Request) override;

private:
    // Per-provider dispatch; each answers only when its provider is active.
    void HandleAwsRequest(HttpServerRequest& Request);
    void HandleAzureRequest(HttpServerRequest& Request);
    void HandleGcpRequest(HttpServerRequest& Request);
};
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_TESTS
diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h
new file mode 100644
index 000000000..071e902b3
--- /dev/null
+++ b/src/zencompute/include/zencompute/orchestratorservice.h
@@ -0,0 +1,177 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencompute/zencompute.h>
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
# include <zencore/compactbinary.h>
# include <zencore/thread.h>
# include <zencore/timer.h>
# include <zencore/uid.h>

# include <atomic>
# include <deque>
# include <filesystem>
# include <memory>
# include <optional>
# include <string>
# include <string_view>
# include <thread>
# include <unordered_map>
+
+namespace zen::compute {
+
+class WorkerTimelineStore;
+
/**
 * Tracks workers and clients that announce themselves to the orchestrator,
 * keeps bounded event logs of their comings and goings, and records worker
 * stats into an on-disk timeline store. A background probe thread maintains
 * per-worker reachability state.
 */
class OrchestratorService
{
public:
    explicit OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false);
    ~OrchestratorService();

    // Non-copyable: owns the probe thread and timeline store.
    OrchestratorService(const OrchestratorService&) = delete;
    OrchestratorService& operator=(const OrchestratorService&) = delete;

    /** Snapshot of a worker's self-reported identity and stats (views borrow
     * from the caller's announcement payload; copied into KnownWorker on use).
     */
    struct WorkerAnnouncement
    {
        std::string_view Id;
        std::string_view Uri;
        std::string_view Hostname;
        std::string_view Platform; // e.g. "windows", "wine", "linux", "macos"
        int Cpus = 0;
        float CpuUsagePercent = 0.0f;
        uint64_t MemoryTotalBytes = 0;
        uint64_t MemoryUsedBytes = 0;
        uint64_t BytesReceived = 0;
        uint64_t BytesSent = 0;
        int ActionsPending = 0;
        int ActionsRunning = 0;
        int ActionsCompleted = 0;
        int ActiveQueues = 0;
        std::string_view Provisioner; // e.g. "horde", "nomad", or empty
    };

    /** Worker fleet membership change, kept in a bounded in-memory log. */
    struct ProvisioningEvent
    {
        enum class Type
        {
            Joined,
            Left,
            Returned
        };
        Type EventType;
        DateTime Timestamp;
        std::string WorkerId;
        std::string Hostname;
    };

    struct ClientAnnouncement
    {
        Oid SessionId;
        std::string_view Hostname;
        std::string_view Address;
        CbObject Metadata;
    };

    /** Client session lifecycle change, kept in a bounded in-memory log. */
    struct ClientEvent
    {
        enum class Type
        {
            Connected,
            Disconnected,
            Updated
        };
        Type EventType;
        DateTime Timestamp;
        std::string ClientId;
        std::string Hostname;
    };

    // Worker tracking

    CbObject GetWorkerList();
    void AnnounceWorker(const WorkerAnnouncement& Announcement);

    bool IsWorkerWebSocketEnabled() const;
    void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected);

    CbObject GetProvisioningHistory(int Limit = 100);

    CbObject GetWorkerTimeline(std::string_view WorkerId, std::optional<DateTime> From, std::optional<DateTime> To, int Limit);

    CbObject GetAllTimelines(DateTime From, DateTime To);

    // Client tracking

    /** Registers a client session; returns its assigned client id string. */
    std::string AnnounceClient(const ClientAnnouncement& Announcement);
    bool UpdateClient(std::string_view ClientId, CbObject Metadata = {});
    bool CompleteClient(std::string_view ClientId);
    CbObject GetClientList();
    CbObject GetClientHistory(int Limit = 100);

private:
    enum class ReachableState
    {
        Unknown,    // not yet probed
        Reachable,
        Unreachable,
    };

    /** Last-announced state for a worker, owned copies of announcement data. */
    struct KnownWorker
    {
        std::string BaseUri;
        Stopwatch LastSeen;   // time since last announcement
        std::string Hostname;
        std::string Platform;
        int Cpus = 0;
        float CpuUsagePercent = 0.0f;
        uint64_t MemoryTotalBytes = 0;
        uint64_t MemoryUsedBytes = 0;
        uint64_t BytesReceived = 0;
        uint64_t BytesSent = 0;
        int ActionsPending = 0;
        int ActionsRunning = 0;
        int ActionsCompleted = 0;
        int ActiveQueues = 0;
        std::string Provisioner;
        ReachableState Reachable = ReachableState::Unknown;
        bool WsConnected = false;
        Stopwatch LastProbed; // time since the probe thread last checked this worker
    };

    RwLock m_KnownWorkersLock; // guards m_KnownWorkers
    std::unordered_map<std::string, KnownWorker> m_KnownWorkers;
    std::unique_ptr<WorkerTimelineStore> m_TimelineStore;

    RwLock m_ProvisioningLogLock; // guards m_ProvisioningLog
    std::deque<ProvisioningEvent> m_ProvisioningLog;
    static constexpr size_t kMaxProvisioningEvents = 1000; // bounded log size

    void RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname);

    /** Last-announced state for a client session (owned copies). */
    struct KnownClient
    {
        Oid SessionId;
        std::string Hostname;
        std::string Address;
        Stopwatch LastSeen;
        CbObject Metadata;
    };

    RwLock m_KnownClientsLock; // guards m_KnownClients
    std::unordered_map<std::string, KnownClient> m_KnownClients;

    RwLock m_ClientLogLock; // guards m_ClientLog
    std::deque<ClientEvent> m_ClientLog;
    static constexpr size_t kMaxClientEvents = 1000; // bounded log size

    void RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname);

    bool m_EnableWorkerWebSocket = false;

    // Background reachability probing; joined in the destructor.
    std::thread m_ProbeThread;
    std::atomic<bool> m_ProbeThreadEnabled{true};
    Event m_ProbeThreadEvent; // wakes/stops the probe thread
    void ProbeThreadFunction();
};
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/include/zencompute/recordingreader.h b/src/zencompute/include/zencompute/recordingreader.h
new file mode 100644
index 000000000..3f233fae0
--- /dev/null
+++ b/src/zencompute/include/zencompute/recordingreader.h
@@ -0,0 +1,129 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencompute/zencompute.h>
+
+#include <zencompute/computeservice.h>
+#include <zencompute/zencompute.h>
+#include <zencore/basicfile.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zenstore/cidstore.h>
+#include <zenstore/gc.h>
+#include <zenstore/zenstore.h>
+
+#include <filesystem>
+#include <functional>
+#include <unordered_map>
+
+namespace zen {
+class CbObject;
+class CbPackage;
+struct IoHash;
+} // namespace zen
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+namespace zen::compute {
+
+//////////////////////////////////////////////////////////////////////////
+
/**
 * Abstract interface for reading back a recorded compute session: the worker
 * descriptors that were used, plus iteration over all recorded actions.
 */
class RecordingReaderBase
{
    RecordingReaderBase(const RecordingReaderBase&) = delete;
    RecordingReaderBase& operator=(const RecordingReaderBase&) = delete;

public:
    RecordingReaderBase() = default;
    // Pure-virtual to keep the base abstract even without other overrides;
    // the required out-of-line definition lives in the implementation file.
    virtual ~RecordingReaderBase() = 0;
    /** All recorded worker descriptor packages, keyed by worker id hash. */
    virtual std::unordered_map<IoHash, CbPackage> ReadWorkers() = 0;
    /** Invokes Callback once per recorded action. TargetParallelism is
     * presumably a concurrency hint — confirm against the implementations.
     */
    virtual void IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, int TargetParallelism) = 0;
    /** Number of actions in the recording. */
    virtual size_t GetActionCount() const = 0;
};
+
+//////////////////////////////////////////////////////////////////////////
+
+/**
+ * Reader for recordings done via the zencompute recording system, which
+ * have a shared chunk store and a log of actions with pointers into the
+ * chunk store for their data.
+ */
/**
 * Reader for recordings done via the zencompute recording system, which
 * have a shared chunk store and a log of actions with pointers into the
 * chunk store for their data.
 */
class RecordingReader : public RecordingReaderBase, public ChunkResolver
{
public:
    explicit RecordingReader(const std::filesystem::path& RecordingPath);
    ~RecordingReader();

    virtual std::unordered_map<zen::IoHash, zen::CbPackage> ReadWorkers() override;

    virtual void IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback,
                                int TargetParallelism) override;
    virtual size_t GetActionCount() const override;

private:
    std::filesystem::path m_RecordingLogDir;
    BasicFile m_WorkerDataFile;  // on-disk worker descriptor data
    BasicFile m_ActionDataFile;  // on-disk action data referenced by m_Actions
    GcManager m_Gc;
    CidStore m_CidStore{m_Gc};   // shared chunk store backing FindChunkByCid

    // ChunkResolver interface
    virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override;

    /** Location of one recorded action inside m_ActionDataFile. */
    struct ActionEntry
    {
        IoHash ActionId;
        uint64_t Offset;
        uint64_t Size;
    };

    std::vector<ActionEntry> m_Actions;

    // Populates m_Actions from the recording log.
    void ScanActions();
};
+
+//////////////////////////////////////////////////////////////////////////
+
/**
 * In-memory ChunkResolver backed by a hash map of explicitly Add()ed buffers.
 * Lookup and insertion are guarded by a reader/writer lock.
 */
struct LocalResolver : public ChunkResolver
{
    LocalResolver(const LocalResolver&) = delete;
    LocalResolver& operator=(const LocalResolver&) = delete;

    LocalResolver() = default;
    ~LocalResolver() = default;

    // Returns the buffer previously Add()ed under Cid.
    virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override;
    void Add(const IoHash& Cid, IoBuffer Data);

private:
    // NOTE(review): members lack the m_ prefix used elsewhere in this module;
    // renaming would touch the implementation file, so left as-is here.
    RwLock MapLock; // guards Attachments
    std::unordered_map<IoHash, IoBuffer> Attachments;
};
+
+/**
+ * This is a reader for UE/DDB recordings, which have a different layout on
+ * disk (no shared chunk store)
+ */
/**
 * This is a reader for UE/DDB recordings, which have a different layout on
 * disk (no shared chunk store)
 */
class UeRecordingReader : public RecordingReaderBase, public ChunkResolver
{
public:
    explicit UeRecordingReader(const std::filesystem::path& RecordingPath);
    ~UeRecordingReader();

    virtual std::unordered_map<zen::IoHash, zen::CbPackage> ReadWorkers() override;
    virtual void IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback,
                                int TargetParallelism) override;
    virtual size_t GetActionCount() const override;
    // ChunkResolver: resolves chunks through the in-memory LocalResolver below.
    virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override;

private:
    std::filesystem::path m_RecordingDir;
    LocalResolver m_LocalResolver;               // chunks loaded from per-action work dirs
    std::vector<std::filesystem::path> m_WorkDirs; // one directory per recorded action

    // Loads a single action package from its work directory.
    CbPackage ReadAction(std::filesystem::path WorkDir);
};
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/include/zencompute/zencompute.h b/src/zencompute/include/zencompute/zencompute.h
new file mode 100644
index 000000000..00be4d4a0
--- /dev/null
+++ b/src/zencompute/include/zencompute/zencompute.h
@@ -0,0 +1,15 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#if !defined(ZEN_WITH_COMPUTE_SERVICES)
+# define ZEN_WITH_COMPUTE_SERVICES 1
+#endif
+
namespace zen {

// Anchor symbol for the module's tests — presumably referenced by the test
// runner so this library's self-registering tests are not dropped by the
// linker. TODO(review): confirm against the test harness.
void zencompute_forcelinktests();

}
diff --git a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp
new file mode 100644
index 000000000..9ea695305
--- /dev/null
+++ b/src/zencompute/orchestratorservice.cpp
@@ -0,0 +1,710 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencompute/orchestratorservice.h>
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/logging.h>
+# include <zencore/trace.h>
+# include <zenhttp/httpclient.h>
+
+# include "timeline/workertimeline.h"
+
+namespace zen::compute {
+
+// Persists worker timelines under DataDir/timelines and starts the background
+// probe thread. EnableWorkerWebSocket toggles the worker WebSocket path.
+OrchestratorService::OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket)
+: m_TimelineStore(std::make_unique<WorkerTimelineStore>(DataDir / "timelines"))
+, m_EnableWorkerWebSocket(EnableWorkerWebSocket)
+{
+	m_ProbeThread = std::thread{&OrchestratorService::ProbeThreadFunction, this};
+}
+
+OrchestratorService::~OrchestratorService()
+{
+	// Ask the probe loop to exit, wake it if it is waiting on the event, then join.
+	m_ProbeThreadEnabled = false;
+	m_ProbeThreadEvent.Set();
+	if (m_ProbeThread.joinable())
+	{
+		m_ProbeThread.join();
+	}
+}
+
+/**
+ * Serializes the current worker table into a compact-binary object with a
+ * single "workers" array. Optional fields (platform, provisioner, reachable,
+ * ws_connected) are emitted only when they carry information.
+ */
+CbObject
+OrchestratorService::GetWorkerList()
+{
+	ZEN_TRACE_CPU("OrchestratorService::GetWorkerList");
+
+	CbObjectWriter Writer;
+	Writer.BeginArray("workers");
+
+	m_KnownWorkersLock.WithSharedLock([&] {
+		for (const auto& [Id, Info] : m_KnownWorkers)
+		{
+			Writer.BeginObject();
+			Writer << "id" << Id;
+			Writer << "uri" << Info.BaseUri;
+			Writer << "hostname" << Info.Hostname;
+			if (!Info.Platform.empty())
+			{
+				Writer << "platform" << std::string_view(Info.Platform);
+			}
+			Writer << "cpus" << Info.Cpus;
+			Writer << "cpu_usage" << Info.CpuUsagePercent;
+			Writer << "memory_total" << Info.MemoryTotalBytes;
+			Writer << "memory_used" << Info.MemoryUsedBytes;
+			Writer << "bytes_received" << Info.BytesReceived;
+			Writer << "bytes_sent" << Info.BytesSent;
+			Writer << "actions_pending" << Info.ActionsPending;
+			Writer << "actions_running" << Info.ActionsRunning;
+			Writer << "actions_completed" << Info.ActionsCompleted;
+			Writer << "active_queues" << Info.ActiveQueues;
+			if (!Info.Provisioner.empty())
+			{
+				Writer << "provisioner" << std::string_view(Info.Provisioner);
+			}
+			if (Info.Reachable != ReachableState::Unknown)
+			{
+				Writer << "reachable" << (Info.Reachable == ReachableState::Reachable);
+			}
+			if (Info.WsConnected)
+			{
+				Writer << "ws_connected" << true;
+			}
+			Writer << "dt" << Info.LastSeen.GetElapsedTimeMs();
+			Writer.EndObject();
+		}
+	});
+
+	Writer.EndArray();
+	return Writer.Save();
+}
+
+// Registers or refreshes a worker from its announcement. If the announced
+// endpoint URI is already claimed under a different worker ID, the stale
+// entry is evicted. Joined/Left provisioning events are recorded after the
+// map lock is released.
+void
+OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann)
+{
+	ZEN_TRACE_CPU("OrchestratorService::AnnounceWorker");
+
+	bool		IsNew = false;
+	std::string EvictedId;
+	std::string EvictedHostname;
+
+	m_KnownWorkersLock.WithExclusiveLock([&] {
+		IsNew = (m_KnownWorkers.find(std::string(Ann.Id)) == m_KnownWorkers.end());
+
+		// If a different worker ID already maps to the same URI, the old entry
+		// is stale (e.g. a previous Horde lease on the same machine). Remove it
+		// so the dashboard doesn't show duplicates.
+		if (IsNew)
+		{
+			for (auto It = m_KnownWorkers.begin(); It != m_KnownWorkers.end(); ++It)
+			{
+				if (It->second.BaseUri == Ann.Uri && It->first != Ann.Id)
+				{
+					EvictedId = It->first;
+					EvictedHostname = It->second.Hostname;
+					m_KnownWorkers.erase(It);
+					break;
+				}
+			}
+		}
+
+		// Insert-or-update; empty Platform/Provisioner strings do not overwrite
+		// a previously announced value.
+		auto& Worker = m_KnownWorkers[std::string(Ann.Id)];
+		Worker.BaseUri = Ann.Uri;
+		Worker.Hostname = Ann.Hostname;
+		if (!Ann.Platform.empty())
+		{
+			Worker.Platform = Ann.Platform;
+		}
+		Worker.Cpus = Ann.Cpus;
+		Worker.CpuUsagePercent = Ann.CpuUsagePercent;
+		Worker.MemoryTotalBytes = Ann.MemoryTotalBytes;
+		Worker.MemoryUsedBytes = Ann.MemoryUsedBytes;
+		Worker.BytesReceived = Ann.BytesReceived;
+		Worker.BytesSent = Ann.BytesSent;
+		Worker.ActionsPending = Ann.ActionsPending;
+		Worker.ActionsRunning = Ann.ActionsRunning;
+		Worker.ActionsCompleted = Ann.ActionsCompleted;
+		Worker.ActiveQueues = Ann.ActiveQueues;
+		if (!Ann.Provisioner.empty())
+		{
+			Worker.Provisioner = Ann.Provisioner;
+		}
+		Worker.LastSeen.Reset();
+	});
+
+	if (!EvictedId.empty())
+	{
+		ZEN_INFO("worker {} superseded by {} (same endpoint)", EvictedId, Ann.Id);
+		RecordProvisioningEvent(ProvisioningEvent::Type::Left, EvictedId, EvictedHostname);
+	}
+
+	if (IsNew)
+	{
+		RecordProvisioningEvent(ProvisioningEvent::Type::Joined, Ann.Id, Ann.Hostname);
+	}
+}
+
+// Whether the worker WebSocket path is enabled (fixed at construction time).
+bool
+OrchestratorService::IsWorkerWebSocketEnabled() const
+{
+	return m_EnableWorkerWebSocket;
+}
+
+// Updates a worker's WebSocket state and derives reachability from it.
+// Unknown worker IDs are ignored. Returned/Left provisioning events are
+// recorded outside the lock, and only when the reachable state actually flips.
+void
+OrchestratorService::SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected)
+{
+	ReachableState PrevState = ReachableState::Unknown;
+	std::string	   WorkerHostname;
+
+	m_KnownWorkersLock.WithExclusiveLock([&] {
+		auto It = m_KnownWorkers.find(std::string(WorkerId));
+		if (It == m_KnownWorkers.end())
+		{
+			// Worker never announced itself — nothing to update
+			return;
+		}
+
+		PrevState = It->second.Reachable;
+		WorkerHostname = It->second.Hostname;
+		It->second.WsConnected = Connected;
+		It->second.Reachable = Connected ? ReachableState::Reachable : ReachableState::Unreachable;
+
+		if (Connected)
+		{
+			ZEN_INFO("worker {} WebSocket connected — marking reachable", WorkerId);
+		}
+		else
+		{
+			ZEN_WARN("worker {} WebSocket disconnected — marking unreachable", WorkerId);
+		}
+	});
+
+	// Record provisioning events for state transitions outside the lock
+	if (Connected && PrevState == ReachableState::Unreachable)
+	{
+		RecordProvisioningEvent(ProvisioningEvent::Type::Returned, WorkerId, WorkerHostname);
+	}
+	else if (!Connected && PrevState == ReachableState::Reachable)
+	{
+		RecordProvisioningEvent(ProvisioningEvent::Type::Left, WorkerId, WorkerHostname);
+	}
+}
+
+// Returns one worker's recorded event timeline as a compact-binary object.
+// If From/To is supplied the window [From, To] is queried; otherwise the most
+// recent events are returned (capped by Limit when Limit > 0). Returns an
+// empty object when no timeline exists for WorkerId.
+CbObject
+OrchestratorService::GetWorkerTimeline(std::string_view WorkerId, std::optional<DateTime> From, std::optional<DateTime> To, int Limit)
+{
+	ZEN_TRACE_CPU("OrchestratorService::GetWorkerTimeline");
+
+	Ref<WorkerTimeline> Timeline = m_TimelineStore->Find(WorkerId);
+	if (!Timeline)
+	{
+		return {};
+	}
+
+	std::vector<WorkerTimeline::Event> Events;
+
+	if (From || To)
+	{
+		// A missing endpoint defaults to the epoch / the current time
+		DateTime StartTime = From.value_or(DateTime(0));
+		DateTime EndTime = To.value_or(DateTime::Now());
+		Events = Timeline->QueryTimeline(StartTime, EndTime);
+	}
+	else if (Limit > 0)
+	{
+		Events = Timeline->QueryRecent(Limit);
+	}
+	else
+	{
+		Events = Timeline->QueryRecent();
+	}
+
+	WorkerTimeline::TimeRange Range = Timeline->GetTimeRange();
+
+	CbObjectWriter Cbo;
+	Cbo << "worker_id" << WorkerId;
+	Cbo << "event_count" << static_cast<int32_t>(Timeline->GetEventCount());
+
+	if (Range)
+	{
+		Cbo.AddDateTime("time_first", Range.First);
+		Cbo.AddDateTime("time_last", Range.Last);
+	}
+
+	Cbo.BeginArray("events");
+	for (const auto& Evt : Events)
+	{
+		Cbo.BeginObject();
+		Cbo << "type" << WorkerTimeline::ToString(Evt.Type);
+		Cbo.AddDateTime("ts", Evt.Timestamp);
+
+		// Only emit action linkage when the event carries an action LSN
+		if (Evt.ActionLsn != 0)
+		{
+			Cbo << "lsn" << Evt.ActionLsn;
+			Cbo << "action_id" << Evt.ActionId;
+		}
+
+		if (Evt.Type == WorkerTimeline::EventType::ActionStateChanged)
+		{
+			Cbo << "prev_state" << RunnerAction::ToString(Evt.PreviousState);
+			Cbo << "state" << RunnerAction::ToString(Evt.ActionState);
+		}
+
+		if (!Evt.Reason.empty())
+		{
+			Cbo << "reason" << std::string_view(Evt.Reason);
+		}
+
+		Cbo.EndObject();
+	}
+	Cbo.EndArray();
+
+	return Cbo.Save();
+}
+
+/**
+ * Lists the workers whose recorded timelines overlap the window [From, To].
+ * The result echoes the query window and, per worker, the first/last recorded
+ * event times. Workers with no recorded range are omitted.
+ */
+CbObject
+OrchestratorService::GetAllTimelines(DateTime From, DateTime To)
+{
+	ZEN_TRACE_CPU("OrchestratorService::GetAllTimelines");
+
+	CbObjectWriter Writer;
+	Writer.AddDateTime("from", From);
+	Writer.AddDateTime("to", To);
+
+	Writer.BeginArray("workers");
+	for (const auto& Info : m_TimelineStore->GetAllWorkerInfo())
+	{
+		// Skip workers whose recorded range misses the query window entirely
+		const bool Overlaps = Info.Range && !(Info.Range.Last < From) && !(Info.Range.First > To);
+		if (!Overlaps)
+		{
+			continue;
+		}
+
+		Writer.BeginObject();
+		Writer << "worker_id" << Info.WorkerId;
+		Writer.AddDateTime("time_first", Info.Range.First);
+		Writer.AddDateTime("time_last", Info.Range.Last);
+		Writer.EndObject();
+	}
+	Writer.EndArray();
+
+	return Writer.Save();
+}
+
+/**
+ * Appends a worker join/leave/return event to the bounded in-memory
+ * provisioning log, evicting the oldest entries beyond kMaxProvisioningEvents.
+ */
+void
+OrchestratorService::RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname)
+{
+	// Build the entry before taking the lock; only append/evict run inside it.
+	ProvisioningEvent Entry{
+		.EventType = Type,
+		.Timestamp = DateTime::Now(),
+		.WorkerId = std::string(WorkerId),
+		.Hostname = std::string(Hostname),
+	};
+
+	m_ProvisioningLogLock.WithExclusiveLock([&] {
+		m_ProvisioningLog.push_back(std::move(Entry));
+
+		// Keep the log bounded; oldest entries go first
+		while (m_ProvisioningLog.size() > kMaxProvisioningEvents)
+		{
+			m_ProvisioningLog.pop_front();
+		}
+	});
+}
+
+/**
+ * Returns up to Limit provisioning events (newest first) as a compact-binary
+ * object with a single "events" array. Non-positive limits fall back to 100.
+ */
+CbObject
+OrchestratorService::GetProvisioningHistory(int Limit)
+{
+	ZEN_TRACE_CPU("OrchestratorService::GetProvisioningHistory");
+
+	if (Limit <= 0)
+	{
+		Limit = 100;
+	}
+
+	// Maps an event type to its wire name
+	const auto TypeName = [](ProvisioningEvent::Type Type) -> std::string_view {
+		switch (Type)
+		{
+			case ProvisioningEvent::Type::Joined:
+				return "joined";
+			case ProvisioningEvent::Type::Left:
+				return "left";
+			case ProvisioningEvent::Type::Returned:
+				return "returned";
+		}
+		return {};
+	};
+
+	CbObjectWriter Writer;
+	Writer.BeginArray("events");
+
+	m_ProvisioningLogLock.WithSharedLock([&] {
+		// Walk the log backwards so the newest events are emitted first
+		int Emitted = 0;
+		for (auto It = m_ProvisioningLog.rbegin(); It != m_ProvisioningLog.rend() && Emitted < Limit; ++It, ++Emitted)
+		{
+			Writer.BeginObject();
+			Writer << "type" << TypeName(It->EventType);
+			Writer.AddDateTime("ts", It->Timestamp);
+			Writer << "worker_id" << std::string_view(It->WorkerId);
+			Writer << "hostname" << std::string_view(It->Hostname);
+			Writer.EndObject();
+		}
+	});
+
+	Writer.EndArray();
+	return Writer.Save();
+}
+
+// Registers a client session and returns the generated client ID.
+//
+// NOTE(review): ClientId is generated fresh ("client-<oid>") on every call, so
+// the find() below can never locate an existing entry — IsNew is always true
+// and the Updated branch is unreachable. Confirm whether the map was meant to
+// be keyed on Ann.SessionId instead.
+std::string
+OrchestratorService::AnnounceClient(const ClientAnnouncement& Ann)
+{
+	ZEN_TRACE_CPU("OrchestratorService::AnnounceClient");
+
+	std::string ClientId = fmt::format("client-{}", Oid::NewOid().ToString());
+
+	bool IsNew = false;
+
+	m_KnownClientsLock.WithExclusiveLock([&] {
+		auto It = m_KnownClients.find(ClientId);
+		IsNew = (It == m_KnownClients.end());
+
+		// Empty Address / absent Metadata do not overwrite existing values
+		auto& Client = m_KnownClients[ClientId];
+		Client.SessionId = Ann.SessionId;
+		Client.Hostname = Ann.Hostname;
+		if (!Ann.Address.empty())
+		{
+			Client.Address = Ann.Address;
+		}
+		if (Ann.Metadata)
+		{
+			Client.Metadata = Ann.Metadata;
+		}
+		Client.LastSeen.Reset();
+	});
+
+	if (IsNew)
+	{
+		RecordClientEvent(ClientEvent::Type::Connected, ClientId, Ann.Hostname);
+	}
+	else
+	{
+		RecordClientEvent(ClientEvent::Type::Updated, ClientId, Ann.Hostname);
+	}
+
+	return ClientId;
+}
+
+/**
+ * Refreshes a known client's last-seen timestamp, replacing its metadata when
+ * a non-empty Metadata object is supplied.
+ *
+ * @return true if the client was known, false otherwise.
+ */
+bool
+OrchestratorService::UpdateClient(std::string_view ClientId, CbObject Metadata)
+{
+	ZEN_TRACE_CPU("OrchestratorService::UpdateClient");
+
+	bool Known = false;
+
+	m_KnownClientsLock.WithExclusiveLock([&] {
+		auto Entry = m_KnownClients.find(std::string(ClientId));
+		if (Entry == m_KnownClients.end())
+		{
+			return;
+		}
+
+		Known = true;
+		if (Metadata)
+		{
+			Entry->second.Metadata = std::move(Metadata);
+		}
+		Entry->second.LastSeen.Reset();
+	});
+
+	return Known;
+}
+
+/**
+ * Removes a client from the known-client table and records a Disconnected
+ * event (outside the lock).
+ *
+ * @return true if the client was known and removed, false otherwise.
+ */
+bool
+OrchestratorService::CompleteClient(std::string_view ClientId)
+{
+	ZEN_TRACE_CPU("OrchestratorService::CompleteClient");
+
+	std::string Hostname;
+	bool		Removed = false;
+
+	m_KnownClientsLock.WithExclusiveLock([&] {
+		auto Entry = m_KnownClients.find(std::string(ClientId));
+		if (Entry == m_KnownClients.end())
+		{
+			return;
+		}
+
+		Removed = true;
+		Hostname = Entry->second.Hostname;
+		m_KnownClients.erase(Entry);
+	});
+
+	if (Removed)
+	{
+		RecordClientEvent(ClientEvent::Type::Disconnected, ClientId, Hostname);
+	}
+
+	return Removed;
+}
+
+// Serializes the currently known clients into a compact-binary object with a
+// single "clients" array. Optional fields (session_id, address, metadata) are
+// emitted only when present.
+CbObject
+OrchestratorService::GetClientList()
+{
+	ZEN_TRACE_CPU("OrchestratorService::GetClientList");
+	CbObjectWriter Cbo;
+	Cbo.BeginArray("clients");
+
+	m_KnownClientsLock.WithSharedLock([&] {
+		for (const auto& [ClientId, Client] : m_KnownClients)
+		{
+			Cbo.BeginObject();
+			Cbo << "id" << ClientId;
+			if (Client.SessionId)
+			{
+				Cbo << "session_id" << Client.SessionId;
+			}
+			Cbo << "hostname" << std::string_view(Client.Hostname);
+			if (!Client.Address.empty())
+			{
+				Cbo << "address" << std::string_view(Client.Address);
+			}
+			// Milliseconds since the client was last seen
+			Cbo << "dt" << Client.LastSeen.GetElapsedTimeMs();
+			if (Client.Metadata)
+			{
+				Cbo << "metadata" << Client.Metadata;
+			}
+			Cbo.EndObject();
+		}
+	});
+
+	Cbo.EndArray();
+	return Cbo.Save();
+}
+
+// Returns up to Limit client connect/disconnect/update events (newest first)
+// as a compact-binary object with a "client_events" array. Non-positive
+// limits fall back to 100.
+CbObject
+OrchestratorService::GetClientHistory(int Limit)
+{
+	ZEN_TRACE_CPU("OrchestratorService::GetClientHistory");
+
+	if (Limit <= 0)
+	{
+		Limit = 100;
+	}
+
+	CbObjectWriter Cbo;
+	Cbo.BeginArray("client_events");
+
+	m_ClientLogLock.WithSharedLock([&] {
+		// Walk the log backwards so the newest events are emitted first
+		int Count = 0;
+		for (auto It = m_ClientLog.rbegin(); It != m_ClientLog.rend() && Count < Limit; ++It, ++Count)
+		{
+			const auto& Evt = *It;
+			Cbo.BeginObject();
+
+			switch (Evt.EventType)
+			{
+				case ClientEvent::Type::Connected:
+					Cbo << "type"
+						<< "connected";
+					break;
+				case ClientEvent::Type::Disconnected:
+					Cbo << "type"
+						<< "disconnected";
+					break;
+				case ClientEvent::Type::Updated:
+					Cbo << "type"
+						<< "updated";
+					break;
+			}
+
+			Cbo.AddDateTime("ts", Evt.Timestamp);
+			Cbo << "client_id" << std::string_view(Evt.ClientId);
+			Cbo << "hostname" << std::string_view(Evt.Hostname);
+			Cbo.EndObject();
+		}
+	});
+
+	Cbo.EndArray();
+	return Cbo.Save();
+}
+
+/**
+ * Appends a client connect/disconnect/update event to the bounded in-memory
+ * client log, evicting the oldest entries beyond kMaxClientEvents.
+ */
+void
+OrchestratorService::RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname)
+{
+	// Build the entry before taking the lock; only append/evict run inside it.
+	ClientEvent Entry{
+		.EventType = Type,
+		.Timestamp = DateTime::Now(),
+		.ClientId = std::string(ClientId),
+		.Hostname = std::string(Hostname),
+	};
+
+	m_ClientLogLock.WithExclusiveLock([&] {
+		m_ClientLog.push_back(std::move(Entry));
+
+		// Keep the log bounded; oldest entries go first
+		while (m_ClientLog.size() > kMaxClientEvents)
+		{
+			m_ClientLog.pop_front();
+		}
+	});
+}
+
+// Background loop for the "orch_probe" thread. Each pass (every ~5 seconds,
+// or immediately when m_ProbeThreadEvent is set for shutdown):
+//   1. snapshots the known workers under a shared lock,
+//   2. probes GET /health/ on each worker without a live WebSocket,
+//   3. updates reachability and records Left/Returned events on transitions,
+//   4. sweeps clients that have not announced themselves for over 5 minutes.
+void
+OrchestratorService::ProbeThreadFunction()
+{
+	ZEN_TRACE_CPU("OrchestratorService::ProbeThreadFunction");
+	SetCurrentThreadName("orch_probe");
+
+	bool IsFirstProbe = true;
+
+	do
+	{
+		// First pass runs immediately; later passes wait for the timeout/event
+		if (!IsFirstProbe)
+		{
+			m_ProbeThreadEvent.Wait(5'000);
+			m_ProbeThreadEvent.Reset();
+		}
+		else
+		{
+			IsFirstProbe = false;
+		}
+
+		if (m_ProbeThreadEnabled == false)
+		{
+			return;
+		}
+
+		// NOTE(review): redundant for non-first iterations (already reset after
+		// Wait above), but harmless
+		m_ProbeThreadEvent.Reset();
+
+		// Snapshot worker IDs and URIs under shared lock
+		struct WorkerSnapshot
+		{
+			std::string Id;
+			std::string Uri;
+			bool		WsConnected = false;
+		};
+		std::vector<WorkerSnapshot> Snapshots;
+
+		m_KnownWorkersLock.WithSharedLock([&] {
+			Snapshots.reserve(m_KnownWorkers.size());
+			for (const auto& [WorkerId, Worker] : m_KnownWorkers)
+			{
+				Snapshots.push_back({WorkerId, Worker.BaseUri, Worker.WsConnected});
+			}
+		});
+
+		// Probe each worker outside the lock
+		for (const auto& Snap : Snapshots)
+		{
+			if (m_ProbeThreadEnabled == false)
+			{
+				return;
+			}
+
+			// Workers with an active WebSocket connection are known-reachable;
+			// skip the HTTP health probe for them.
+			if (Snap.WsConnected)
+			{
+				continue;
+			}
+
+			ReachableState NewState = ReachableState::Unreachable;
+
+			try
+			{
+				HttpClient Client(Snap.Uri,
+								  {.ConnectTimeout = std::chrono::milliseconds{3000}, .Timeout = std::chrono::milliseconds{5000}});
+				HttpClient::Response Response = Client.Get("/health/");
+				if (Response.IsSuccess())
+				{
+					NewState = ReachableState::Reachable;
+				}
+			}
+			catch (const std::exception& Ex)
+			{
+				ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what());
+			}
+
+			ReachableState PrevState = ReachableState::Unknown;
+			std::string	   WorkerHostname;
+
+			// Re-find the worker: it may have been evicted while probing
+			m_KnownWorkersLock.WithExclusiveLock([&] {
+				auto It = m_KnownWorkers.find(Snap.Id);
+				if (It != m_KnownWorkers.end())
+				{
+					PrevState = It->second.Reachable;
+					WorkerHostname = It->second.Hostname;
+					It->second.Reachable = NewState;
+					It->second.LastProbed.Reset();
+
+					if (PrevState != NewState)
+					{
+						if (NewState == ReachableState::Reachable && PrevState == ReachableState::Unreachable)
+						{
+							ZEN_INFO("worker {} ({}) is reachable again", Snap.Id, Snap.Uri);
+						}
+						else if (NewState == ReachableState::Reachable)
+						{
+							ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri);
+						}
+						else if (PrevState == ReachableState::Reachable)
+						{
+							ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri);
+						}
+						else
+						{
+							ZEN_WARN("worker {} ({}) is not reachable", Snap.Id, Snap.Uri);
+						}
+					}
+				}
+			});
+
+			// Record provisioning events for state transitions outside the lock
+			if (PrevState != NewState)
+			{
+				if (NewState == ReachableState::Unreachable && PrevState == ReachableState::Reachable)
+				{
+					RecordProvisioningEvent(ProvisioningEvent::Type::Left, Snap.Id, WorkerHostname);
+				}
+				else if (NewState == ReachableState::Reachable && PrevState == ReachableState::Unreachable)
+				{
+					RecordProvisioningEvent(ProvisioningEvent::Type::Returned, Snap.Id, WorkerHostname);
+				}
+			}
+		}
+
+		// Sweep expired clients (5-minute timeout)
+		static constexpr int64_t kClientTimeoutMs = 5 * 60 * 1000;
+
+		struct ExpiredClient
+		{
+			std::string Id;
+			std::string Hostname;
+		};
+		std::vector<ExpiredClient> ExpiredClients;
+
+		m_KnownClientsLock.WithExclusiveLock([&] {
+			for (auto It = m_KnownClients.begin(); It != m_KnownClients.end();)
+			{
+				if (It->second.LastSeen.GetElapsedTimeMs() > kClientTimeoutMs)
+				{
+					ExpiredClients.push_back({It->first, It->second.Hostname});
+					It = m_KnownClients.erase(It);
+				}
+				else
+				{
+					++It;
+				}
+			}
+		});
+
+		// Emit disconnect events outside the lock
+		for (const auto& Expired : ExpiredClients)
+		{
+			ZEN_INFO("client {} timed out (no announcement for >5 minutes)", Expired.Id);
+			RecordClientEvent(ClientEvent::Type::Disconnected, Expired.Id, Expired.Hostname);
+		}
+	} while (m_ProbeThreadEnabled);
+}
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/recording/actionrecorder.cpp b/src/zencompute/recording/actionrecorder.cpp
new file mode 100644
index 000000000..90141ca55
--- /dev/null
+++ b/src/zencompute/recording/actionrecorder.cpp
@@ -0,0 +1,258 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "actionrecorder.h"
+
+#include "../runners/functionrunner.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinaryfile.h>
+#include <zencore/compactbinaryvalue.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <ppl.h>
+# define ZEN_CONCRT_AVAILABLE 1
+#else
+# define ZEN_CONCRT_AVAILABLE 0
+#endif
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+
+RecordingFileWriter::RecordingFileWriter()
+{
+}
+
+// Finalizes and flushes the TOC even when Close() was not called explicitly.
+RecordingFileWriter::~RecordingFileWriter()
+{
+	Close();
+}
+
+// Creates/truncates the data file (writing a 16-byte magic header) and the
+// companion ".ztoc" file, then starts the TOC object ("version" + "toc" array).
+void
+RecordingFileWriter::Open(std::filesystem::path FilePath)
+{
+	using namespace std::literals;
+
+	m_File.Open(FilePath, BasicFile::Mode::kTruncate);
+	m_File.Write("----DDC2----DATA", 16, 0);
+	m_FileOffset = 16;
+
+	// replace_extension mutates the by-value FilePath copy, which is fine here
+	std::filesystem::path TocPath = FilePath.replace_extension(".ztoc");
+	m_TocFile.Open(TocPath, BasicFile::Mode::kTruncate);
+
+	m_TocWriter << "version"sv << 1;
+	m_TocWriter.BeginArray("toc"sv);
+}
+
+// Ends the TOC array and writes the finished TOC object to the .ztoc file.
+// NOTE(review): invoked from the destructor; calling Close() twice, or without
+// a prior Open(), would end the TOC array again / write to an unopened file —
+// confirm CbObjectWriter/BasicFile tolerate this. The WriteAll error code is
+// currently ignored.
+void
+RecordingFileWriter::Close()
+{
+	m_TocWriter.EndArray();
+	CbObject Toc = m_TocWriter.Save();
+
+	std::error_code Ec;
+	m_TocFile.WriteAll(Toc.GetBuffer().AsIoBuffer(), Ec);
+}
+
+// Appends Object's raw buffer at the current end of the data file and records
+// a [hash, offset, size] TOC entry. Thread-safe via m_FileLock.
+// Throws std::system_error if the file write fails.
+void
+RecordingFileWriter::AppendObject(const CbObject& Object, const IoHash& ObjectHash)
+{
+	RwLock::ExclusiveLockScope _(m_FileLock);
+
+	MemoryView ObjectView = Object.GetBuffer().GetView();
+
+	std::error_code Ec;
+	m_File.Write(ObjectView, m_FileOffset, Ec);
+
+	if (Ec)
+	{
+		throw std::system_error(Ec, "failed writing to archive");
+	}
+
+	// TOC entry: [hash, offset, size]
+	m_TocWriter.BeginArray();
+	m_TocWriter.AddHash(ObjectHash);
+	m_TocWriter.AddInteger(m_FileOffset);
+	m_TocWriter.AddInteger(gsl::narrow<int>(ObjectView.GetSize()));
+	m_TocWriter.EndArray();
+
+	m_FileOffset += ObjectView.GetSize();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// Prepares a clean recording directory (created and emptied), opens the
+// workers/actions data files, and initializes the recording's private CID
+// store under <dir>/cid. Directory errors are logged but not fatal here.
+ActionRecorder::ActionRecorder(ChunkResolver& InChunkResolver, const std::filesystem::path& RecordingLogPath)
+: m_ChunkResolver(InChunkResolver)
+, m_RecordingLogDir(RecordingLogPath)
+{
+	std::error_code Ec;
+	CreateDirectories(m_RecordingLogDir, Ec);
+
+	if (Ec)
+	{
+		ZEN_WARN("Could not create directory '{}': {}", m_RecordingLogDir, Ec.message());
+	}
+
+	// Start from an empty directory so the recording is self-contained
+	CleanDirectory(m_RecordingLogDir, /* ForceRemoveReadOnlyFiles */ true, Ec);
+
+	if (Ec)
+	{
+		ZEN_WARN("Could not clean directory '{}': {}", m_RecordingLogDir, Ec.message());
+	}
+
+	m_WorkersFile.Open(m_RecordingLogDir / "workers.zdat");
+	m_ActionsFile.Open(m_RecordingLogDir / "actions.zdat");
+
+	CidStoreConfiguration CidConfig;
+	CidConfig.RootDirectory = m_RecordingLogDir / "cid";
+	CidConfig.HugeValueThreshold = 128 * 1024 * 1024;
+
+	m_CidStore.Initialize(CidConfig);
+}
+
+ActionRecorder::~ActionRecorder()
+{
+	Shutdown();
+}
+
+// Flushes the recording's CID store; safe to call before destruction.
+void
+ActionRecorder::Shutdown()
+{
+	m_CidStore.Flush();
+}
+
+// Records a worker package: appends its descriptor object to workers.zdat and
+// copies every attachment chunk — both those carried in the package and those
+// only referenced by CID — into the recording's private CID store, so the
+// recording stays self-contained. Unresolvable chunks are logged.
+void
+ActionRecorder::RegisterWorker(const CbPackage& WorkerPackage)
+{
+	const IoHash WorkerId = WorkerPackage.GetObjectHash();
+
+	m_WorkersFile.AppendObject(WorkerPackage.GetObject(), WorkerId);
+
+	std::unordered_set<IoHash> AddedChunks;
+	uint64_t				   AddedBytes = 0;
+
+	// First add all attachments from the worker package itself
+
+	for (const CbAttachment& Attachment : WorkerPackage.GetAttachments())
+	{
+		CompressedBuffer Buffer = Attachment.AsCompressedBinary();
+		IoBuffer		 Data = Buffer.GetCompressed().Flatten().AsIoBuffer();
+
+		const IoHash ChunkHash = Buffer.DecodeRawHash();
+
+		CidStore::InsertResult Result = m_CidStore.AddChunk(Data, ChunkHash, CidStore::InsertMode::kCopyOnly);
+
+		AddedChunks.insert(ChunkHash);
+
+		// Only count bytes for chunks not already present in the store
+		if (Result.New)
+		{
+			AddedBytes += Data.GetSize();
+		}
+	}
+
+	// Not all attachments will be present in the worker package, so we need to add
+	// all referenced chunks to ensure that the recording is self-contained and not
+	// referencing data in the main CID store
+
+	CbObject WorkerDescriptor = WorkerPackage.GetObject();
+
+	WorkerDescriptor.IterateAttachments([&](const CbFieldView AttachmentField) {
+		const IoHash AttachmentCid = AttachmentField.GetValue().AsHash();
+
+		if (!AddedChunks.contains(AttachmentCid))
+		{
+			IoBuffer AttachmentData = m_ChunkResolver.FindChunkByCid(AttachmentCid);
+
+			if (AttachmentData)
+			{
+				CidStore::InsertResult Result = m_CidStore.AddChunk(AttachmentData, AttachmentCid, CidStore::InsertMode::kCopyOnly);
+
+				if (Result.New)
+				{
+					AddedBytes += AttachmentData.GetSize();
+				}
+			}
+			else
+			{
+				ZEN_WARN("RegisterWorker: could not resolve attachment chunk {} for worker {}", AttachmentCid, WorkerId);
+			}
+
+			// Mark as seen even when unresolved so we don't retry it
+			AddedChunks.insert(AttachmentCid);
+		}
+	});
+
+	ZEN_INFO("recorded worker {} with {} attachments ({} bytes)", WorkerId, AddedChunks.size(), AddedBytes);
+}
+
+// Records one action: copies every attachment chunk it references into the
+// recording's CID store, then appends the action object to actions.zdat.
+// Chunks stored without a concrete Oodle compressor are re-compressed with
+// Mermaid/Fast first. Returns false (and skips the action object) if any
+// attachment chunk cannot be resolved.
+bool
+ActionRecorder::RecordAction(Ref<RunnerAction> Action)
+{
+	bool AllGood = true;
+
+	Action->ActionObj.IterateAttachments([&](CbFieldView Field) {
+		IoHash	 AttachData = Field.AsHash();
+		IoBuffer ChunkData = m_ChunkResolver.FindChunkByCid(AttachData);
+
+		if (ChunkData)
+		{
+			if (ChunkData.GetContentType() == ZenContentType::kCompressedBinary)
+			{
+				IoHash			 DecompressedHash;
+				uint64_t		 RawSize = 0;
+				CompressedBuffer Compressed =
+					CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), /* out */ DecompressedHash, /* out*/ RawSize);
+
+				OodleCompressor		  Compressor;
+				OodleCompressionLevel CompressionLevel;
+				uint64_t			  BlockSize = 0;
+				if (Compressed.TryGetCompressParameters(/* out */ Compressor, /* out */ CompressionLevel, /* out */ BlockSize))
+				{
+					// "NotSet" means the buffer is effectively uncompressed;
+					// re-compress it so the recording stores compact data
+					if (Compressor == OodleCompressor::NotSet)
+					{
+						CompositeBuffer	 Decompressed = Compressed.DecompressToComposite();
+						CompressedBuffer NewCompressed = CompressedBuffer::Compress(std::move(Decompressed),
+																					OodleCompressor::Mermaid,
+																					OodleCompressionLevel::Fast,
+																					BlockSize);
+
+						ChunkData = NewCompressed.GetCompressed().Flatten().AsIoBuffer();
+					}
+				}
+			}
+
+			const uint64_t ChunkSize = ChunkData.GetSize();
+
+			m_CidStore.AddChunk(ChunkData, AttachData, CidStore::InsertMode::kCopyOnly);
+			++m_ChunkCounter;
+			m_ChunkBytesCounter.fetch_add(ChunkSize);
+		}
+		else
+		{
+			AllGood = false;
+
+			ZEN_WARN("could not resolve chunk {}", AttachData);
+		}
+	});
+
+	if (AllGood)
+	{
+		m_ActionsFile.AppendObject(Action->ActionObj, Action->ActionId);
+		++m_ActionsCounter;
+
+		return true;
+	}
+	else
+	{
+		return false;
+	}
+}
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/recording/actionrecorder.h b/src/zencompute/recording/actionrecorder.h
new file mode 100644
index 000000000..2827b6ac7
--- /dev/null
+++ b/src/zencompute/recording/actionrecorder.h
@@ -0,0 +1,91 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencompute/computeservice.h>
+#include <zencompute/zencompute.h>
+#include <zencore/basicfile.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zenstore/cidstore.h>
+#include <zenstore/gc.h>
+#include <zenstore/zenstore.h>
+
+#include <filesystem>
+#include <functional>
+#include <map>
+#include <unordered_map>
+
+namespace zen {
+class CbObject;
+class CbPackage;
+struct IoHash;
+} // namespace zen
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+namespace zen::compute {
+
+//////////////////////////////////////////////////////////////////////////
+
+/**
+ * Append-only writer pairing a raw data file with a table-of-contents file.
+ *
+ * Objects are appended back-to-back to the data file; the TOC records each
+ * object's hash, offset and size and is written out when Close() runs.
+ * AppendObject is thread-safe via m_FileLock.
+ */
+struct RecordingFileWriter
+{
+	RecordingFileWriter(RecordingFileWriter&&) = delete;
+	RecordingFileWriter& operator=(RecordingFileWriter&&) = delete;
+
+	RwLock		   m_FileLock;			// Serializes AppendObject calls
+	BasicFile	   m_File;				// Data file
+	uint64_t	   m_FileOffset = 0;	// Next write position in m_File
+	CbObjectWriter m_TocWriter;			// Accumulates TOC entries until Close()
+	BasicFile	   m_TocFile;			// Companion ".ztoc" file
+
+	RecordingFileWriter();
+	~RecordingFileWriter();
+
+	// Creates/truncates the data and TOC files and writes the data header.
+	void Open(std::filesystem::path FilePath);
+	// Finalizes and persists the TOC.
+	void Close();
+	// Appends Object to the data file and records it in the TOC.
+	void AppendObject(const CbObject& Object, const IoHash& ObjectHash);
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+/**
+ * Recording "runner" implementation
+ *
+ * This class writes out all actions and their attachments to a recording directory
+ * in a format that can be read back by the RecordingReader.
+ *
+ * The contents of the recording directory will be self-contained, with all referenced
+ * attachments stored in the recording directory itself, so that the recording can be
+ * moved or shared without needing to maintain references to the main CID store.
+ *
+ */
+
+class ActionRecorder
+{
+public:
+	// RecordingLogPath is created (and emptied) as the recording output dir.
+	ActionRecorder(ChunkResolver& InChunkResolver, const std::filesystem::path& RecordingLogPath);
+	~ActionRecorder();
+
+	ActionRecorder(const ActionRecorder&) = delete;
+	ActionRecorder& operator=(const ActionRecorder&) = delete;
+
+	// Flushes the recording's private CID store.
+	void Shutdown();
+	// Records a worker package plus all chunks it references.
+	void RegisterWorker(const CbPackage& WorkerPackage);
+	// Records an action; returns false when an attachment chunk is missing.
+	bool RecordAction(Ref<RunnerAction> Action);
+
+private:
+	ChunkResolver&		  m_ChunkResolver;		// Source for attachment chunk data
+	std::filesystem::path m_RecordingLogDir;
+
+	RecordingFileWriter	  m_WorkersFile;		// workers.zdat / workers.ztoc
+	RecordingFileWriter	  m_ActionsFile;		// actions.zdat / actions.ztoc
+	GcManager			  m_Gc;
+	CidStore			  m_CidStore{m_Gc};		// Self-contained chunk store under <dir>/cid
+	std::atomic<int>	  m_ChunkCounter{0};
+	std::atomic<uint64_t> m_ChunkBytesCounter{0};
+	std::atomic<int>	  m_ActionsCounter{0};
+};
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/recording/recordingreader.cpp b/src/zencompute/recording/recordingreader.cpp
new file mode 100644
index 000000000..1c1a119cf
--- /dev/null
+++ b/src/zencompute/recording/recordingreader.cpp
@@ -0,0 +1,335 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencompute/recordingreader.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinaryfile.h>
+#include <zencore/compactbinaryvalue.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <ppl.h>
+# define ZEN_CONCRT_AVAILABLE 1
+#else
+# define ZEN_CONCRT_AVAILABLE 0
+#endif
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+
+# if ZEN_PLATFORM_WINDOWS
+# define ZEN_BUILD_ACTION L"Build.action"
+# define ZEN_WORKER_UCB L"worker.ucb"
+# else
+# define ZEN_BUILD_ACTION "Build.action"
+# define ZEN_WORKER_UCB "worker.ucb"
+# endif
+
+//////////////////////////////////////////////////////////////////////////
+
+/**
+ * File-system visitor that collects the directories of a UE/DDB recording
+ * tree: directories containing "Build.action" (action work dirs) and
+ * directories containing "worker.ucb" (worker descriptor dirs).
+ */
+struct RecordingTreeVisitor : public FileSystemTraversal::TreeVisitor
+{
+	// Marked override so signature drift against the base class is a compile
+	// error instead of a silently unused overload.
+	virtual void VisitFile(const std::filesystem::path& Parent,
+						   const path_view&				File,
+						   uint64_t						FileSize,
+						   uint32_t						NativeModeOrAttributes,
+						   uint64_t						NativeModificationTick) override
+	{
+		ZEN_UNUSED(Parent, File, FileSize, NativeModeOrAttributes, NativeModificationTick);
+
+		if (File.compare(path_view(ZEN_BUILD_ACTION)) == 0)
+		{
+			WorkDirs.push_back(Parent);
+		}
+		else if (File.compare(path_view(ZEN_WORKER_UCB)) == 0)
+		{
+			WorkerDirs.push_back(Parent);
+		}
+	}
+
+	// Always descend; the whole recording tree is scanned.
+	virtual bool VisitDirectory(const std::filesystem::path& Parent, const path_view& DirectoryName, uint32_t NativeModeOrAttributes) override
+	{
+		ZEN_UNUSED(Parent, DirectoryName, NativeModeOrAttributes);
+
+		return true;
+	}
+
+	std::vector<std::filesystem::path> WorkerDirs;
+	std::vector<std::filesystem::path> WorkDirs;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+/**
+ * Invokes Func on every element of Array, optionally in parallel.
+ *
+ * With the Concurrency Runtime available (Windows) and TargetParallelism > 1,
+ * elements are processed by concurrency::parallel_for_each, chunked so each
+ * task handles roughly size/TargetParallelism elements. Otherwise the
+ * elements are processed sequentially in order.
+ */
+void
+IterateOverArray(auto Array, auto Func, int TargetParallelism)
+{
+# if ZEN_CONCRT_AVAILABLE
+	if (TargetParallelism > 1)
+	{
+		// simple_partitioner requires a chunk size of at least 1
+		size_t ChunkSize = Array.size() / TargetParallelism;
+		if (ChunkSize == 0)
+		{
+			ChunkSize = 1;
+		}
+
+		// Pass the partitioner to the algorithm — previously it was
+		// constructed but never used, so the chunking never took effect.
+		concurrency::simple_partitioner Chunker(ChunkSize);
+		concurrency::parallel_for_each(begin(Array), end(Array), [&](const auto& Item) { Func(Item); }, Chunker);
+
+		return;
+	}
+# else
+	ZEN_UNUSED(TargetParallelism);
+# endif
+
+	for (const auto& Item : Array)
+	{
+		Func(Item);
+	}
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// Out-of-line defaulted destructor for the abstract reader interface.
+RecordingReaderBase::~RecordingReaderBase() = default;
+
+//////////////////////////////////////////////////////////////////////////
+
+// Opens the recording's private CID store located under <recording>/cid.
+RecordingReader::RecordingReader(const std::filesystem::path& RecordingPath) : m_RecordingLogDir(RecordingPath)
+{
+	CidStoreConfiguration CidConfig;
+	CidConfig.RootDirectory = m_RecordingLogDir / "cid";
+	CidConfig.HugeValueThreshold = 128 * 1024 * 1024;
+
+	m_CidStore.Initialize(CidConfig);
+}
+
+RecordingReader::~RecordingReader()
+{
+	// Flush the CID store before tearing it down
+	m_CidStore.Flush();
+}
+
+// Number of actions discovered by ScanActions() (via ReadWorkers()).
+size_t
+RecordingReader::GetActionCount() const
+{
+	return m_Actions.size();
+}
+
+/**
+ * Looks the chunk up in the recording's private CID store. Returns an empty
+ * buffer (after logging an error) when the CID is unknown.
+ */
+IoBuffer
+RecordingReader::FindChunkByCid(const IoHash& DecompressedId)
+{
+	IoBuffer Chunk = m_CidStore.FindChunkByCid(DecompressedId);
+
+	if (!Chunk)
+	{
+		ZEN_ERROR("failed lookup of chunk with CID '{}'", DecompressedId);
+	}
+
+	return Chunk;
+}
+
+// Loads all recorded worker packages: reads the workers TOC, extracts each
+// descriptor object from workers.zdat by [offset, size], and re-attaches the
+// descriptor's chunks from the recording's CID store. Also scans the actions
+// TOC as a side effect (see note below). Returns descriptors keyed by hash.
+std::unordered_map<zen::IoHash, zen::CbPackage>
+RecordingReader::ReadWorkers()
+{
+	std::unordered_map<zen::IoHash, zen::CbPackage> WorkerMap;
+
+	{
+		CbObjectFromFile TocFile = LoadCompactBinaryObject(m_RecordingLogDir / "workers.ztoc");
+		CbObject		 Toc = TocFile.Object;
+
+		m_WorkerDataFile.Open(m_RecordingLogDir / "workers.zdat", BasicFile::Mode::kRead);
+
+		// Only TOC version 1 is understood
+		ZEN_ASSERT(Toc["version"sv].AsInt32() == 1);
+
+		for (auto& It : Toc["toc"])
+		{
+			// Each TOC entry is an array of [hash, offset, size]
+			CbArrayView			Entry = It.AsArrayView();
+			CbFieldViewIterator Vit = Entry.CreateViewIterator();
+
+			const IoHash   WorkerId = Vit++->AsHash();
+			const uint64_t Offset = Vit++->AsInt64(0);
+			const uint64_t Size = Vit++->AsInt64(0);
+
+			IoBuffer  WorkerRange = m_WorkerDataFile.ReadRange(Offset, Size);
+			CbObject  WorkerDesc = LoadCompactBinaryObject(WorkerRange);
+			CbPackage& WorkerPkg = WorkerMap[WorkerId];
+			WorkerPkg.SetObject(WorkerDesc);
+
+			// Attach every referenced chunk found in the recording CID store;
+			// missing chunks are silently skipped here
+			WorkerDesc.IterateAttachments([&](const zen::CbFieldView AttachmentField) {
+				const IoHash AttachmentCid = AttachmentField.GetValue().AsHash();
+				IoBuffer	 AttachmentData = m_CidStore.FindChunkByCid(AttachmentCid);
+
+				if (AttachmentData)
+				{
+					IoHash			 RawHash;
+					uint64_t		 RawSize = 0;
+					CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize);
+					WorkerPkg.AddAttachment(CbAttachment(CompressedData, RawHash));
+				}
+			});
+		}
+	}
+
+	// Scan actions as well (this should be called separately, ideally)
+
+	ScanActions();
+
+	return WorkerMap;
+}
+
+// Reads the actions TOC and builds the m_Actions index of (hash, offset,
+// size) entries; opens actions.zdat for the later per-action reads.
+void
+RecordingReader::ScanActions()
+{
+	CbObjectFromFile TocFile = LoadCompactBinaryObject(m_RecordingLogDir / "actions.ztoc");
+	CbObject		 Toc = TocFile.Object;
+
+	m_ActionDataFile.Open(m_RecordingLogDir / "actions.zdat", BasicFile::Mode::kRead);
+
+	// Only TOC version 1 is understood
+	ZEN_ASSERT(Toc["version"sv].AsInt32() == 1);
+
+	for (auto& It : Toc["toc"])
+	{
+		// Each TOC entry is an array of [hash, offset, size]
+		CbArrayView			ArrayEntry = It.AsArrayView();
+		CbFieldViewIterator Vit = ArrayEntry.CreateViewIterator();
+
+		ActionEntry Entry;
+		Entry.ActionId = Vit++->AsHash();
+		Entry.Offset = Vit++->AsInt64(0);
+		Entry.Size = Vit++->AsInt64(0);
+
+		m_Actions.push_back(Entry);
+	}
+}
+
+// Visit every recorded action: each descriptor is loaded on demand from the
+// data file using the offsets gathered by ScanActions(). Invocations may run
+// concurrently, bounded by TargetParallelism.
+void
+RecordingReader::IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, int TargetParallelism)
+{
+    auto VisitEntry = [&](const ActionEntry& Entry) {
+        IoBuffer DescBytes  = m_ActionDataFile.ReadRange(Entry.Offset, Entry.Size);
+        CbObject ActionDesc = LoadCompactBinaryObject(DescBytes);
+
+        Callback(ActionDesc, Entry.ActionId);
+    };
+
+    IterateOverArray(m_Actions, VisitEntry, TargetParallelism);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// Look up a previously registered attachment by its (decompressed) content id.
+// Returns an empty buffer when the id is unknown. Thread-safe (shared lock).
+IoBuffer
+LocalResolver::FindChunkByCid(const IoHash& DecompressedId)
+{
+    RwLock::SharedLockScope _(MapLock);
+
+    auto It = Attachments.find(DecompressedId);
+    return (It != Attachments.end()) ? It->second : IoBuffer{};
+}
+
+// Register a compressed attachment under its content id, replacing any
+// previous entry for the same id. Thread-safe (exclusive lock).
+void
+LocalResolver::Add(const IoHash& Cid, IoBuffer Data)
+{
+    // Data is a by-value local, so tagging it needs no lock; keep the
+    // critical section down to the map insertion itself.
+    Data.SetContentType(ZenContentType::kCompressedBinary);
+
+    RwLock::ExclusiveLockScope _(MapLock);
+    Attachments[Cid] = Data;
+}
+
+///
+
+// Reader for the on-disk recording layout produced by Unreal Engine: a tree of
+// worker/work directories (traversed in ReadWorkers) rather than the packed
+// .ztoc/.zdat files handled by RecordingReader.
+UeRecordingReader::UeRecordingReader(const std::filesystem::path& RecordingPath) : m_RecordingDir(RecordingPath)
+{
+}
+
+UeRecordingReader::~UeRecordingReader()
+{
+}
+
+// Number of work (action) directories discovered while traversing the
+// recording tree. Returns 0 until ReadWorkers() has run.
+size_t
+UeRecordingReader::GetActionCount() const
+{
+    return m_WorkDirs.size();
+}
+
+// Resolve an attachment by content id from the in-memory resolver, which is
+// populated as actions are read (see ReadAction).
+IoBuffer
+UeRecordingReader::FindChunkByCid(const IoHash& DecompressedId)
+{
+    return m_LocalResolver.FindChunkByCid(DecompressedId);
+}
+
+// Traverse the recording directory tree, remember the discovered work
+// directories (used later by IterateActions), and build a CbPackage for every
+// worker found: the "worker.ucb" descriptor becomes the package object, and
+// each of its attachments is loaded from the worker's "chunks" directory and
+// re-wrapped as a compressed attachment.
+std::unordered_map<zen::IoHash, zen::CbPackage>
+UeRecordingReader::ReadWorkers()
+{
+    std::unordered_map<zen::IoHash, zen::CbPackage> WorkerMap;
+
+    FileSystemTraversal  Traversal;
+    RecordingTreeVisitor Visitor;
+    Traversal.TraverseFileSystem(m_RecordingDir, Visitor);
+
+    // Keep the work dirs for IterateActions(); the visitor is then discarded.
+    m_WorkDirs = std::move(Visitor.WorkDirs);
+
+    for (const std::filesystem::path& WorkerDir : Visitor.WorkerDirs)
+    {
+        CbObjectFromFile WorkerFile = LoadCompactBinaryObject(WorkerDir / "worker.ucb");
+        CbObject         WorkerDesc = WorkerFile.Object;
+        const IoHash&    WorkerId   = WorkerFile.Hash;
+        CbPackage&       WorkerPkg  = WorkerMap[WorkerId];
+        WorkerPkg.SetObject(WorkerDesc);
+
+        // Attachments are stored one file per chunk, named by hex content id.
+        WorkerDesc.IterateAttachments([&](const zen::CbFieldView AttachmentField) {
+            const IoHash AttachmentCid  = AttachmentField.GetValue().AsHash();
+            IoBuffer     AttachmentData = ReadFile(WorkerDir / "chunks" / AttachmentCid.ToHexString()).Flatten();
+            IoHash       RawHash;
+            uint64_t     RawSize = 0;
+            CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize);
+            WorkerPkg.AddAttachment(CbAttachment(CompressedData, RawHash));
+        });
+    }
+
+    return WorkerMap;
+}
+
+// Visit every recorded action. Each work directory is loaded into a package on
+// demand and handed to the callback; up to ParallelismTarget directories are
+// processed concurrently.
+void
+UeRecordingReader::IterateActions(std::function<void(CbObject ActionObject, const IoHash& ActionId)>&& Callback, int ParallelismTarget)
+{
+    auto VisitWorkDir = [&](const std::filesystem::path& WorkDir) {
+        CbPackage Package = ReadAction(WorkDir);
+
+        Callback(Package.GetObject(), Package.GetObjectHash());
+    };
+
+    IterateOverArray(m_WorkDirs, VisitWorkDir, ParallelismTarget);
+}
+
+// Load a single action from WorkDir: the "Build.action" descriptor becomes the
+// package object, and every attachment is read from the "inputs" directory.
+// Attachments are also published to m_LocalResolver so later FindChunkByCid()
+// calls can resolve them.
+CbPackage
+UeRecordingReader::ReadAction(std::filesystem::path WorkDir)
+{
+    CbPackage             WorkPackage;
+    std::filesystem::path WorkDescPath = WorkDir / "Build.action";
+    CbObjectFromFile      ActionFile   = LoadCompactBinaryObject(WorkDescPath);
+    CbObject&             ActionObject = ActionFile.Object;
+
+    WorkPackage.SetObject(ActionObject);
+
+    ActionObject.IterateAttachments([&](const zen::CbFieldView AttachmentField) {
+        const IoHash AttachmentCid  = AttachmentField.GetValue().AsHash();
+        IoBuffer     AttachmentData = ReadFile(WorkDir / "inputs" / AttachmentCid.ToHexString()).Flatten();
+
+        // Make the raw compressed chunk resolvable via FindChunkByCid().
+        m_LocalResolver.Add(AttachmentCid, AttachmentData);
+
+        IoHash   RawHash;
+        uint64_t RawSize = 0;
+        CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize);
+        // The file name (content id) must match the decompressed payload hash.
+        ZEN_ASSERT(AttachmentCid == RawHash);
+        WorkPackage.AddAttachment(CbAttachment(CompressedData, RawHash));
+    });
+
+    return WorkPackage;
+}
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/runners/deferreddeleter.cpp b/src/zencompute/runners/deferreddeleter.cpp
new file mode 100644
index 000000000..4fad2cf70
--- /dev/null
+++ b/src/zencompute/runners/deferreddeleter.cpp
@@ -0,0 +1,340 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "deferreddeleter.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/logging.h>
+# include <zencore/thread.h>
+
+# include <algorithm>
+# include <chrono>
+
+namespace zen::compute {
+
+using namespace std::chrono_literals;
+
+using Clock = std::chrono::steady_clock;
+
+// Default deferral: how long to wait before attempting deletion.
+// This gives memory-mapped file handles time to close naturally.
+static constexpr auto DeferralPeriod = 60s;
+
+// Shortened deferral after MarkReady(): the client has collected results
+// so handles should be released soon, but we still wait briefly.
+static constexpr auto ReadyGracePeriod = 5s;
+
+// Interval between retry attempts for directories that failed deletion.
+static constexpr auto RetryInterval = 5s;
+
+static constexpr int MaxRetries = 10;
+
+// Start the background cleanup thread immediately; it idles on the condition
+// variable until work is enqueued.
+DeferredDirectoryDeleter::DeferredDirectoryDeleter() : m_Thread(&DeferredDirectoryDeleter::ThreadFunction, this)
+{
+}
+
+// Destruction drains all pending deletions and joins the thread; see
+// Shutdown(), which is idempotent.
+DeferredDirectoryDeleter::~DeferredDirectoryDeleter()
+{
+    Shutdown();
+}
+
+// Queue Path for deletion after the deferral period, tagged with ActionLsn so
+// MarkReady() can later shorten the wait. Wakes the worker thread.
+void
+DeferredDirectoryDeleter::Enqueue(int ActionLsn, std::filesystem::path Path)
+{
+    QueueEntry Entry{ActionLsn, std::move(Path)};
+    {
+        std::lock_guard Lock(m_Mutex);
+        m_Queue.push_back(std::move(Entry));
+    }
+    // Notify outside the lock so the woken thread can acquire it immediately.
+    m_Cv.notify_one();
+}
+
+// Record that the action's results have been consumed; the worker thread will
+// shorten matching entries' deadlines to the short grace period.
+void
+DeferredDirectoryDeleter::MarkReady(int ActionLsn)
+{
+    std::unique_lock Lock(m_Mutex);
+    m_ReadyLsns.push_back(ActionLsn);
+    Lock.unlock();
+
+    m_Cv.notify_one();
+}
+
+// Ask the worker thread to delete everything still pending and exit, then join
+// it. Safe to call more than once: joinable() is false after the first call.
+void
+DeferredDirectoryDeleter::Shutdown()
+{
+    {
+        std::scoped_lock Lock(m_Mutex);
+        m_Done = true;
+    }
+    m_Cv.notify_one();
+
+    if (!m_Thread.joinable())
+    {
+        return;
+    }
+
+    m_Thread.join();
+}
+
+// Worker thread main loop: waits for enqueue/MarkReady/shutdown events, tracks
+// a per-directory deadline, and deletes directories once their deferral has
+// elapsed, retrying failures up to MaxRetries. On shutdown every pending
+// directory is processed immediately regardless of its deadline.
+void
+DeferredDirectoryDeleter::ThreadFunction()
+{
+    SetCurrentThreadName("ZenDirCleanup");
+
+    // Thread-local working set: entries move here from m_Queue under the lock,
+    // after which PendingList is touched only by this thread.
+    struct PendingEntry
+    {
+        int                   ActionLsn;
+        std::filesystem::path Path;
+        Clock::time_point     ReadyTime;  // earliest time deletion may be attempted
+        int                   Attempts = 0;
+    };
+
+    std::vector<PendingEntry> PendingList;
+
+    // remove_all with error_code: reports failure without throwing; deleting a
+    // non-existent path is a success (no error).
+    auto TryDelete = [](PendingEntry& Entry) -> bool {
+        std::error_code Ec;
+        std::filesystem::remove_all(Entry.Path, Ec);
+        return !Ec;
+    };
+
+    for (;;)
+    {
+        bool Shutting = false;
+
+        // Drain the incoming queue and process MarkReady signals
+
+        {
+            std::unique_lock Lock(m_Mutex);
+
+            if (m_Queue.empty() && m_ReadyLsns.empty() && !m_Done)
+            {
+                if (PendingList.empty())
+                {
+                    // Nothing pending at all: sleep until any event arrives.
+                    m_Cv.wait(Lock, [this] { return !m_Queue.empty() || !m_ReadyLsns.empty() || m_Done; });
+                }
+                else
+                {
+                    // Sleep only until the earliest pending deadline so due
+                    // entries are processed on time.
+                    auto NextReady = PendingList.front().ReadyTime;
+                    for (const auto& Entry : PendingList)
+                    {
+                        if (Entry.ReadyTime < NextReady)
+                        {
+                            NextReady = Entry.ReadyTime;
+                        }
+                    }
+
+                    m_Cv.wait_until(Lock, NextReady, [this] { return !m_Queue.empty() || !m_ReadyLsns.empty() || m_Done; });
+                }
+            }
+
+            // Move new items into PendingList with the full deferral deadline
+            auto Now = Clock::now();
+            for (auto& Entry : m_Queue)
+            {
+                PendingList.push_back({Entry.ActionLsn, std::move(Entry.Path), Now + DeferralPeriod, 0});
+            }
+            m_Queue.clear();
+
+            // Apply MarkReady: shorten ReadyTime for matching entries.
+            // Deadlines are only ever shortened, never extended.
+            for (int Lsn : m_ReadyLsns)
+            {
+                for (auto& Entry : PendingList)
+                {
+                    if (Entry.ActionLsn == Lsn)
+                    {
+                        auto NewReady = Now + ReadyGracePeriod;
+                        if (NewReady < Entry.ReadyTime)
+                        {
+                            Entry.ReadyTime = NewReady;
+                        }
+                    }
+                }
+            }
+            m_ReadyLsns.clear();
+
+            Shutting = m_Done;
+        }
+
+        // Process items whose deferral period has elapsed (or all items on shutdown)
+
+        auto Now = Clock::now();
+
+        for (size_t i = 0; i < PendingList.size();)
+        {
+            auto& Entry = PendingList[i];
+
+            if (!Shutting && Now < Entry.ReadyTime)
+            {
+                ++i;
+                continue;
+            }
+
+            if (TryDelete(Entry))
+            {
+                if (Entry.Attempts > 0)
+                {
+                    ZEN_INFO("Retry succeeded for directory '{}'", Entry.Path);
+                }
+
+                // Swap-and-pop removal; ordering of PendingList is irrelevant.
+                PendingList[i] = std::move(PendingList.back());
+                PendingList.pop_back();
+            }
+            else
+            {
+                ++Entry.Attempts;
+
+                if (Entry.Attempts >= MaxRetries)
+                {
+                    ZEN_WARN("Giving up on deleting '{}' after {} attempts", Entry.Path, Entry.Attempts);
+                    PendingList[i] = std::move(PendingList.back());
+                    PendingList.pop_back();
+                }
+                else
+                {
+                    // Schedule the next attempt after the retry interval.
+                    ZEN_WARN("Unable to delete directory '{}' (attempt {}), will retry", Entry.Path, Entry.Attempts);
+                    Entry.ReadyTime = Now + RetryInterval;
+                    ++i;
+                }
+            }
+        }
+
+        // Exit once shutdown is requested and nothing remains
+
+        if (Shutting && PendingList.empty())
+        {
+            return;
+        }
+    }
+}
+
+} // namespace zen::compute
+
+#endif
+
+#if ZEN_WITH_TESTS
+
+# include <zencore/testing.h>
+
+namespace zen::compute {
+
+// Dummy symbol referenced from the test harness to force this translation unit
+// to be linked into test builds (so the TEST_CASEs below are registered).
+void
+deferreddeleter_forcelink()
+{
+}
+
+} // namespace zen::compute
+
+#endif
+
+#if ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/testutils.h>
+
+namespace zen::compute {
+
+TEST_SUITE_BEGIN("compute.deferreddeleter");
+
+// Destruction must synchronously drain the queue: even though the 60s deferral
+// has not elapsed, the directory is gone once the deleter is destroyed.
+TEST_CASE("DeferredDirectoryDeleter.DeletesSingleDirectory")
+{
+    ScopedTemporaryDirectory TempDir;
+    std::filesystem::path    DirToDelete = TempDir.Path() / "subdir";
+    CreateDirectories(DirToDelete / "nested");
+
+    CHECK(std::filesystem::exists(DirToDelete));
+
+    {
+        DeferredDirectoryDeleter Deleter;
+        Deleter.Enqueue(1, DirToDelete);
+    }
+
+    CHECK(!std::filesystem::exists(DirToDelete));
+}
+
+// Multiple queued directories with distinct LSNs are all removed on destruction.
+TEST_CASE("DeferredDirectoryDeleter.DeletesMultipleDirectories")
+{
+    ScopedTemporaryDirectory TempDir;
+
+    constexpr int NumDirs = 10;
+    std::vector<std::filesystem::path> Dirs;
+
+    for (int i = 0; i < NumDirs; ++i)
+    {
+        auto Dir = TempDir.Path() / std::to_string(i);
+        CreateDirectories(Dir / "child");
+        Dirs.push_back(std::move(Dir));
+    }
+
+    {
+        DeferredDirectoryDeleter Deleter;
+        for (int i = 0; i < NumDirs; ++i)
+        {
+            CHECK(std::filesystem::exists(Dirs[i]));
+            Deleter.Enqueue(100 + i, Dirs[i]);
+        }
+    }
+
+    for (const auto& Dir : Dirs)
+    {
+        CHECK(!std::filesystem::exists(Dir));
+    }
+}
+
+// Calling Shutdown() twice must be harmless (it is documented as idempotent).
+TEST_CASE("DeferredDirectoryDeleter.ShutdownIsIdempotent")
+{
+    ScopedTemporaryDirectory TempDir;
+    std::filesystem::path    Dir = TempDir.Path() / "idempotent";
+    CreateDirectories(Dir);
+
+    DeferredDirectoryDeleter Deleter;
+    Deleter.Enqueue(42, Dir);
+    Deleter.Shutdown();
+    Deleter.Shutdown();
+
+    CHECK(!std::filesystem::exists(Dir));
+}
+
+// Enqueueing a path that does not exist must neither crash nor hang; the only
+// expectation is clean destruction.
+TEST_CASE("DeferredDirectoryDeleter.HandlesNonExistentPath")
+{
+    ScopedTemporaryDirectory TempDir;
+    std::filesystem::path    NoSuchDir = TempDir.Path() / "does_not_exist";
+
+    {
+        DeferredDirectoryDeleter Deleter;
+        Deleter.Enqueue(99, NoSuchDir);
+    }
+}
+
+// An explicit Shutdown() (rather than relying on the destructor) also drains
+// the queue before returning.
+TEST_CASE("DeferredDirectoryDeleter.ExplicitShutdownBeforeDestruction")
+{
+    ScopedTemporaryDirectory TempDir;
+    std::filesystem::path    Dir = TempDir.Path() / "explicit";
+    CreateDirectories(Dir / "inner");
+
+    DeferredDirectoryDeleter Deleter;
+    Deleter.Enqueue(7, Dir);
+    Deleter.Shutdown();
+
+    CHECK(!std::filesystem::exists(Dir));
+}
+
+// Exercises the MarkReady() code path; the subsequent Shutdown() still forces
+// immediate deletion, so this does not actually wait out the grace period.
+TEST_CASE("DeferredDirectoryDeleter.MarkReadyShortensDeferral")
+{
+    ScopedTemporaryDirectory TempDir;
+    std::filesystem::path    Dir = TempDir.Path() / "markready";
+    CreateDirectories(Dir / "child");
+
+    DeferredDirectoryDeleter Deleter;
+    Deleter.Enqueue(50, Dir);
+
+    // Without MarkReady the full deferral (60s) would apply.
+    // MarkReady shortens it to 5s, and shutdown bypasses even that.
+    Deleter.MarkReady(50);
+    Deleter.Shutdown();
+
+    CHECK(!std::filesystem::exists(Dir));
+}
+
+TEST_SUITE_END();
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/runners/deferreddeleter.h b/src/zencompute/runners/deferreddeleter.h
new file mode 100644
index 000000000..9b010aa0f
--- /dev/null
+++ b/src/zencompute/runners/deferreddeleter.h
@@ -0,0 +1,68 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencompute/computeservice.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <condition_variable>
+# include <deque>
+# include <filesystem>
+# include <mutex>
+# include <thread>
+# include <vector>
+
+namespace zen::compute {
+
+/// Deletes directories on a background thread to avoid blocking callers.
+/// Useful when DeleteDirectories may stall (e.g. Wine's deferred-unlink semantics).
+///
+/// Enqueued directories wait for a deferral period before deletion, giving
+/// file handles time to close. Call MarkReady() with the ActionLsn to shorten
+/// the wait to a brief grace period (e.g. once a client has collected results).
+/// On shutdown, all pending directories are deleted immediately.
+class DeferredDirectoryDeleter
+{
+    // Non-copyable: owns a background thread and its synchronization state.
+    DeferredDirectoryDeleter(const DeferredDirectoryDeleter&) = delete;
+    DeferredDirectoryDeleter& operator=(const DeferredDirectoryDeleter&) = delete;
+
+public:
+    DeferredDirectoryDeleter();
+    ~DeferredDirectoryDeleter();
+
+    /// Enqueue a directory for deferred deletion, associated with an action LSN.
+    void Enqueue(int ActionLsn, std::filesystem::path Path);
+
+    /// Signal that the action result has been consumed and the directory
+    /// can be deleted after a short grace period instead of the full deferral.
+    void MarkReady(int ActionLsn);
+
+    /// Drain the queue and join the background thread. Idempotent.
+    void Shutdown();
+
+private:
+    // A directory awaiting deletion, tagged with the LSN used by MarkReady().
+    struct QueueEntry
+    {
+        int                   ActionLsn;
+        std::filesystem::path Path;
+    };
+
+    std::mutex              m_Mutex;      // guards m_Queue, m_ReadyLsns and m_Done
+    std::condition_variable m_Cv;         // signalled on enqueue / ready / shutdown
+    std::deque<QueueEntry>  m_Queue;      // newly enqueued entries, drained by the thread
+    std::vector<int>        m_ReadyLsns;  // LSNs whose deferral should be shortened
+    bool                    m_Done = false;  // shutdown requested
+    std::thread             m_Thread;     // declared last so state above is built first
+    void                    ThreadFunction();
+};
+
+} // namespace zen::compute
+
+#endif
+
+#if ZEN_WITH_TESTS
+namespace zen::compute {
+void deferreddeleter_forcelink(); // internal
+} // namespace zen::compute
+#endif
diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp
new file mode 100644
index 000000000..768cdf1e1
--- /dev/null
+++ b/src/zencompute/runners/functionrunner.cpp
@@ -0,0 +1,365 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "functionrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/filesystem.h>
+# include <zencore/trace.h>
+
+# include <fmt/format.h>
+# include <vector>
+
+namespace zen::compute {
+
+// BasePath is the runner's working root; dumped action descriptors (see
+// MaybeDumpAction) go into its "actions" subdirectory.
+FunctionRunner::FunctionRunner(std::filesystem::path BasePath) : m_ActionsPath(BasePath / "actions")
+{
+}
+
+// Out-of-line definition for the pure virtual destructor declared in the header.
+FunctionRunner::~FunctionRunner() = default;
+
+// Default capacity: a runner that does not report anything more specific is
+// assumed to accept one action at a time. Subclasses override this to reflect
+// their real parallelism.
+size_t
+FunctionRunner::QueryCapacity()
+{
+    return 1;
+}
+
+// Default batch submission: forward each action to SubmitAction() in order,
+// producing one result per action. Subclasses may override with a genuinely
+// batched implementation.
+std::vector<SubmitResult>
+FunctionRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
+{
+    std::vector<SubmitResult> Results;
+    Results.reserve(Actions.size());
+
+    for (size_t Index = 0; Index < Actions.size(); ++Index)
+    {
+        Results.push_back(SubmitAction(Actions[Index]));
+    }
+
+    return Results;
+}
+
+// When action dumping is enabled, persist the action descriptor to
+// "<actions>/<lsn>.ddb" for offline inspection; otherwise do nothing.
+void
+FunctionRunner::MaybeDumpAction(int ActionLsn, const CbObject& ActionObject)
+{
+    if (!m_DumpActions)
+    {
+        return;
+    }
+
+    std::string                 UniqueId = fmt::format("{}.ddb", ActionLsn);
+    const std::filesystem::path DumpPath = m_ActionsPath / UniqueId;
+
+    zen::WriteFile(DumpPath, IoBuffer(ActionObject.GetBuffer().AsIoBuffer()));
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// Append a runner to the group under the exclusive lock. The Ref stored in
+// m_Runners takes shared ownership of the ref-counted runner.
+void
+BaseRunnerGroup::AddRunnerInternal(FunctionRunner* Runner)
+{
+    m_RunnersLock.WithExclusiveLock([&] { m_Runners.emplace_back(Runner); });
+}
+
+// Aggregate capacity of the group: the sum of every member runner's reported
+// capacity, sampled under the shared lock.
+size_t
+BaseRunnerGroup::QueryCapacity()
+{
+    return m_RunnersLock.WithSharedLock([&] {
+        size_t Total = 0;
+        for (const Ref<FunctionRunner>& Runner : m_Runners)
+        {
+            Total += Runner->QueryCapacity();
+        }
+        return Total;
+    });
+}
+
+// Round-robin submission: try each runner starting at the rotating cursor
+// until one accepts. On acceptance the cursor advances so the next submission
+// starts at the following runner. Returns IsAccepted=false when the group is
+// empty or every runner rejected the action.
+SubmitResult
+BaseRunnerGroup::SubmitAction(Ref<RunnerAction> Action)
+{
+    ZEN_TRACE_CPU("BaseRunnerGroup::SubmitAction");
+    RwLock::SharedLockScope _(m_RunnersLock);
+
+    const int RunnerCount = gsl::narrow<int>(m_Runners.size());
+
+    if (RunnerCount == 0)
+    {
+        return {.IsAccepted = false, .Reason = "No runners available"};
+    }
+
+    // Normalize the stored cursor: the runner list may have shrunk since the
+    // cursor was last written (RemoveRunnerIf), leaving it out of range. The
+    // previous code compared a normalized running index against the raw start
+    // value, so with a stale cursor the termination test could never trigger
+    // and the loop would spin forever when every runner rejected.
+    const int StartIndex = m_NextSubmitIndex.load(std::memory_order_acquire) % RunnerCount;
+    int       Index      = StartIndex;
+
+    do
+    {
+        FunctionRunner& Runner = *m_Runners[Index];
+        Index                  = (Index + 1) % RunnerCount;
+
+        SubmitResult Result = Runner.SubmitAction(Action);
+
+        if (Result.IsAccepted)
+        {
+            // Next submission starts at the runner after the one that accepted.
+            m_NextSubmitIndex.store(Index, std::memory_order_release);
+
+            return Result;
+        }
+    } while (Index != StartIndex);
+
+    // Full cycle, nobody accepted.
+    return {.IsAccepted = false};
+}
+
+// Capacity-weighted batch submission: distribute the actions across runners in
+// proportion to their reported capacity, submit one batch per runner, then
+// stitch the per-runner results back together in the original action order.
+// Returns exactly one SubmitResult per input action.
+std::vector<SubmitResult>
+BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
+{
+    ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions");
+    RwLock::SharedLockScope _(m_RunnersLock);
+
+    const int RunnerCount = gsl::narrow<int>(m_Runners.size());
+
+    if (RunnerCount == 0)
+    {
+        return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"});
+    }
+
+    // Query capacity per runner and compute total
+    std::vector<size_t> Capacities(RunnerCount);
+    size_t              TotalCapacity = 0;
+
+    for (int i = 0; i < RunnerCount; ++i)
+    {
+        Capacities[i] = m_Runners[i]->QueryCapacity();
+        TotalCapacity += Capacities[i];
+    }
+
+    if (TotalCapacity == 0)
+    {
+        return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"});
+    }
+
+    // First pass: give each runner a share proportional to its capacity
+    // (rounded up, clamped to that capacity).
+    std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions(RunnerCount);
+    std::vector<size_t>                         ActionRunnerIndex(Actions.size());
+    size_t                                      ActionIdx = 0;
+
+    for (int i = 0; i < RunnerCount; ++i)
+    {
+        if (Capacities[i] == 0)
+        {
+            continue;
+        }
+
+        size_t Share = (Actions.size() * Capacities[i] + TotalCapacity - 1) / TotalCapacity;
+        Share        = std::min(Share, Capacities[i]);
+
+        for (size_t j = 0; j < Share && ActionIdx < Actions.size(); ++j, ++ActionIdx)
+        {
+            PerRunnerActions[i].push_back(Actions[ActionIdx]);
+            ActionRunnerIndex[ActionIdx] = i;
+        }
+    }
+
+    // Second pass: place remaining actions on runners with spare capacity.
+    // FIX: when the batch exceeds the total capacity no runner ever has spare
+    // room, and the previous unconditional round-robin loop never terminated.
+    // Track whether a full cycle placed anything and stop once it does not.
+    bool MadeProgress = true;
+    for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount)
+    {
+        if (i == 0)
+        {
+            if (!MadeProgress)
+            {
+                break;  // a whole cycle placed nothing; everyone is full
+            }
+            MadeProgress = false;
+        }
+
+        if (Capacities[i] > PerRunnerActions[i].size())
+        {
+            PerRunnerActions[i].push_back(Actions[ActionIdx]);
+            ActionRunnerIndex[ActionIdx] = i;
+            ++ActionIdx;
+            MadeProgress = true;
+        }
+    }
+
+    // Overflow fallback: spread whatever is left round-robin across runners
+    // that reported any capacity at all; each runner reports rejection of the
+    // excess through its own SubmitResult entries. Terminates because
+    // TotalCapacity > 0 implies at least one runner with nonzero capacity.
+    for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount)
+    {
+        if (Capacities[i] > 0)
+        {
+            PerRunnerActions[i].push_back(Actions[ActionIdx]);
+            ActionRunnerIndex[ActionIdx] = i;
+            ++ActionIdx;
+        }
+    }
+
+    // Submit batches per runner
+    std::vector<std::vector<SubmitResult>> PerRunnerResults(RunnerCount);
+
+    for (int i = 0; i < RunnerCount; ++i)
+    {
+        if (!PerRunnerActions[i].empty())
+        {
+            PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]);
+        }
+    }
+
+    // Reassemble results in original action order
+    std::vector<SubmitResult> Results(Actions.size());
+    std::vector<size_t>       PerRunnerIdx(RunnerCount, 0);
+
+    for (size_t i = 0; i < Actions.size(); ++i)
+    {
+        size_t RunnerIdx = ActionRunnerIndex[i];
+        size_t Idx       = PerRunnerIdx[RunnerIdx]++;
+        Results[i]       = std::move(PerRunnerResults[RunnerIdx][Idx]);
+    }
+
+    return Results;
+}
+
+// Total number of actions currently submitted across all runners in the group.
+size_t
+BaseRunnerGroup::GetSubmittedActionCount()
+{
+    return m_RunnersLock.WithSharedLock([&] {
+        size_t Count = 0;
+        for (const Ref<FunctionRunner>& Runner : m_Runners)
+        {
+            Count += Runner->GetSubmittedActionCount();
+        }
+        return Count;
+    });
+}
+
+// Broadcast a worker description package to every runner in the group.
+void
+BaseRunnerGroup::RegisterWorker(CbPackage Worker)
+{
+    m_RunnersLock.WithSharedLock([&] {
+        for (Ref<FunctionRunner>& Runner : m_Runners)
+        {
+            Runner->RegisterWorker(Worker);
+        }
+    });
+}
+
+// Shut down every runner in the group. The runners stay in the list; callers
+// are expected to discard the group afterwards.
+void
+BaseRunnerGroup::Shutdown()
+{
+    m_RunnersLock.WithSharedLock([&] {
+        for (Ref<FunctionRunner>& Runner : m_Runners)
+        {
+            Runner->Shutdown();
+        }
+    });
+}
+
+// Forward a cancellation request to each runner in turn until one claims the
+// action. Returns true if some runner signalled the cancellation.
+bool
+BaseRunnerGroup::CancelAction(int ActionLsn)
+{
+    RwLock::SharedLockScope _(m_RunnersLock);
+
+    for (size_t Index = 0; Index < m_Runners.size(); ++Index)
+    {
+        if (m_Runners[Index]->CancelAction(ActionLsn))
+        {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+// Ask every runner to cancel the remote queue mapped to QueueId; runners for
+// which this is meaningless ignore the call (see FunctionRunner).
+void
+BaseRunnerGroup::CancelRemoteQueue(int QueueId)
+{
+    RwLock::SharedLockScope _(m_RunnersLock);
+
+    for (size_t Index = 0; Index < m_Runners.size(); ++Index)
+    {
+        m_Runners[Index]->CancelRemoteQueue(QueueId);
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// A freshly created action starts in State::New; record its creation time in
+// the per-state timestamp table.
+RunnerAction::RunnerAction(ComputeServiceSession* OwnerSession) : m_OwnerSession(OwnerSession)
+{
+    this->Timestamps[static_cast<int>(State::New)] = DateTime::Now().GetTicks();
+}
+
+RunnerAction::~RunnerAction()
+{
+}
+
+// Return a terminally-failed action to the scheduler for another attempt.
+// Only Failed or Abandoned actions may be reset; returns false otherwise, or
+// when another thread changed the state concurrently. On success the action's
+// per-attempt bookkeeping is cleared, the retry counter is bumped, and the
+// action is re-posted to its owning session's update pipeline.
+bool
+RunnerAction::ResetActionStateToPending()
+{
+    // Only allow reset from Failed or Abandoned states
+    State CurrentState = m_ActionState.load();
+
+    if (CurrentState != State::Failed && CurrentState != State::Abandoned)
+    {
+        return false;
+    }
+
+    // CAS so that a concurrent transition between the load above and here
+    // aborts the reset instead of clobbering the newer state.
+    if (!m_ActionState.compare_exchange_strong(CurrentState, State::Pending))
+    {
+        return false;
+    }
+
+    // Clear timestamps from Submitting through _Count
+    for (int i = static_cast<int>(State::Submitting); i < static_cast<int>(State::_Count); ++i)
+    {
+        this->Timestamps[i] = 0;
+    }
+
+    // Record new Pending timestamp
+    this->Timestamps[static_cast<int>(State::Pending)] = DateTime::Now().GetTicks();
+
+    // Clear execution fields
+    ExecutionLocation.clear();
+    CpuUsagePercent.store(-1.0f, std::memory_order_relaxed);
+    CpuSeconds.store(0.0f, std::memory_order_relaxed);
+
+    // Increment retry count
+    RetryCount.fetch_add(1, std::memory_order_relaxed);
+
+    // Re-enter the scheduler pipeline
+    m_OwnerSession->PostUpdate(this);
+
+    return true;
+}
+
+// Advance the action's state machine. States are ordered by enum value and
+// transitions to the same or an earlier state are silently ignored, so a
+// late/out-of-order notification cannot move an action backwards. On a
+// successful transition the owning session is notified via PostUpdate.
+// NOTE(review): the timestamp for NewState is written before the transition is
+// validated, so a rejected (duplicate/backwards) call still overwrites that
+// state's timestamp — confirm this is intentional.
+void
+RunnerAction::SetActionState(State NewState)
+{
+    ZEN_ASSERT(NewState < State::_Count);
+    this->Timestamps[static_cast<int>(NewState)] = DateTime::Now().GetTicks();
+
+    // CAS loop: retry when another thread advances the state between our load
+    // and the compare-exchange.
+    do
+    {
+        if (State CurrentState = m_ActionState.load(); CurrentState == NewState)
+        {
+            // No state change
+            return;
+        }
+        else
+        {
+            if (NewState <= CurrentState)
+            {
+                // Cannot transition to an earlier or same state
+                return;
+            }
+
+            if (m_ActionState.compare_exchange_strong(CurrentState, NewState))
+            {
+                // Successful state change
+
+                m_OwnerSession->PostUpdate(this);
+
+                return;
+            }
+        }
+    } while (true);
+}
+
+// Store the action's final result package (moved in).
+void
+RunnerAction::SetResult(CbPackage&& Result)
+{
+    m_Result = std::move(Result);
+}
+
+// Access the result package; only valid once the action has reached a terminal
+// state (asserted via IsCompleted()).
+CbPackage&
+RunnerAction::GetResult()
+{
+    ZEN_ASSERT(IsCompleted());
+    return m_Result;
+}
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES \ No newline at end of file
diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h
new file mode 100644
index 000000000..f67414dbb
--- /dev/null
+++ b/src/zencompute/runners/functionrunner.h
@@ -0,0 +1,214 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencompute/computeservice.h>
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <atomic>
+# include <filesystem>
+# include <vector>
+
+namespace zen::compute {
+
+/// Outcome of submitting an action to a runner. When IsAccepted is false,
+/// Reason optionally carries a human-readable explanation.
+struct SubmitResult
+{
+    bool        IsAccepted = false;
+    std::string Reason;
+};
+
+/** Base interface for classes implementing a remote execution "runner".
+ *
+ *  A runner accepts RunnerActions for execution and reports per-action
+ *  acceptance through SubmitResult. Instances are ref-counted and shared via
+ *  Ref<> (see BaseRunnerGroup).
+ */
+class FunctionRunner : public RefCounted
+{
+    // Non-movable: instances are shared by reference.
+    FunctionRunner(FunctionRunner&&) = delete;
+    FunctionRunner& operator=(FunctionRunner&&) = delete;
+
+public:
+    /// BasePath: working root; action dumps go into its "actions" subfolder.
+    FunctionRunner(std::filesystem::path BasePath);
+    virtual ~FunctionRunner() = 0;
+
+    /// Stop accepting work and release runner resources.
+    virtual void Shutdown() = 0;
+    /// Make a worker description package available to this runner.
+    virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0;
+
+    /// Try to submit one action; the result states whether it was accepted.
+    [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0;
+    [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0;
+    [[nodiscard]] virtual bool IsHealthy() = 0;
+    /// Number of actions the runner can take on; base implementation returns 1.
+    [[nodiscard]] virtual size_t QueryCapacity();
+    /// Batch submission; base implementation forwards to SubmitAction().
+    [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);
+
+    // Best-effort cancellation of a specific in-flight action. Returns true if the
+    // cancellation signal was successfully sent. The action will transition to Cancelled
+    // asynchronously once the platform-level termination completes.
+    virtual bool CancelAction(int /*ActionLsn*/) { return false; }
+
+    // Cancel the remote queue corresponding to the given local QueueId.
+    // Only meaningful for remote runners; local runners ignore this.
+    virtual void CancelRemoteQueue(int /*QueueId*/) {}
+
+protected:
+    std::filesystem::path m_ActionsPath;         // where MaybeDumpAction writes .ddb files
+    bool                  m_DumpActions = false; // enables MaybeDumpAction persistence
+    /// Persist the action descriptor to disk when dumping is enabled.
+    void MaybeDumpAction(int ActionLsn, const CbObject& ActionObject);
+};
+
+/** Base class for RunnerGroup that operates on generic FunctionRunner references.
+ *  All scheduling, capacity, and lifecycle logic lives here; the typed
+ *  RunnerGroup<> subclass only adds type-safe add/remove.
+ */
+class BaseRunnerGroup
+{
+public:
+    /// Sum of all member runners' reported capacities.
+    size_t QueryCapacity();
+    /// Round-robin submission across runners until one accepts.
+    SubmitResult SubmitAction(Ref<RunnerAction> Action);
+    /// Capacity-weighted batch submission; results keep the input order.
+    std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);
+    /// Total actions currently submitted across all runners.
+    size_t GetSubmittedActionCount();
+    /// Broadcast a worker description to every runner.
+    void RegisterWorker(CbPackage Worker);
+    /// Shut down every runner in the group.
+    void Shutdown();
+    /// Try each runner in turn; true if any runner signalled the cancellation.
+    bool CancelAction(int ActionLsn);
+    /// Forward queue cancellation to every runner.
+    void CancelRemoteQueue(int QueueId);
+
+    size_t GetRunnerCount()
+    {
+        return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); });
+    }
+
+protected:
+    /// Append a runner under the exclusive lock (type-erased entry point).
+    void AddRunnerInternal(FunctionRunner* Runner);
+
+    RwLock                           m_RunnersLock;          // guards m_Runners
+    std::vector<Ref<FunctionRunner>> m_Runners;
+    std::atomic<int>                 m_NextSubmitIndex{0};   // round-robin cursor
+};
+
+/** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal.
+ */
+template<typename RunnerType>
+struct RunnerGroup : public BaseRunnerGroup
+{
+    /// Add a runner of the concrete type managed by this group.
+    void AddRunner(RunnerType* Runner) { AddRunnerInternal(Runner); }
+
+    /// Shut down and remove every runner for which Pred returns true.
+    /// Returns the number of runners removed. Takes the exclusive lock.
+    template<typename Predicate>
+    size_t RemoveRunnerIf(Predicate&& Pred)
+    {
+        size_t RemovedCount = 0;
+        m_RunnersLock.WithExclusiveLock([&] {
+            auto It = m_Runners.begin();
+            while (It != m_Runners.end())
+            {
+                if (Pred(static_cast<RunnerType&>(**It)))
+                {
+                    // Shut the runner down before dropping our reference to it.
+                    (*It)->Shutdown();
+                    It = m_Runners.erase(It);
+                    ++RemovedCount;
+                }
+                else
+                {
+                    ++It;
+                }
+            }
+        });
+        return RemovedCount;
+    }
+};
+
+/**
+ * This represents an action going through different stages of scheduling and execution.
+ * States advance strictly forward (see SetActionState); the tick time at which
+ * each state was entered is recorded in Timestamps.
+ */
+struct RunnerAction : public RefCounted
+{
+    explicit RunnerAction(ComputeServiceSession* OwnerSession);
+    ~RunnerAction();
+
+    int        ActionLsn = 0;      // log sequence number identifying this action locally
+    int        QueueId   = 0;      // local queue this action belongs to
+    WorkerDesc Worker;             // worker description for this action
+    IoHash     ActionId;           // content hash of the action
+    CbObject   ActionObj;          // action descriptor object
+    int        Priority  = 0;
+    std::string ExecutionLocation; // "local" or remote hostname
+
+    // CPU usage and total CPU time of the running process, sampled periodically by the local runner.
+    // CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage.
+    // CpuSeconds: total CPU time (user+system) consumed since process start, in seconds. 0.0 if not yet sampled.
+    std::atomic<float> CpuUsagePercent{-1.0f};
+    std::atomic<float> CpuSeconds{0.0f};
+    std::atomic<int>   RetryCount{0};  // bumped by ResetActionStateToPending()
+
+    // Lifecycle states in strictly increasing order; SetActionState() only
+    // allows transitions to a higher value (the explicit reset to Pending in
+    // ResetActionStateToPending() is the one sanctioned exception).
+    enum class State
+    {
+        New,
+        Pending,
+        Submitting,
+        Running,
+        Completed,
+        Failed,
+        Abandoned,
+        Cancelled,
+        _Count
+    };
+
+    // Stable, human-readable state name; also the wire form parsed by FromString().
+    static const char* ToString(State _)
+    {
+        switch (_)
+        {
+            case State::New:
+                return "New";
+            case State::Pending:
+                return "Pending";
+            case State::Submitting:
+                return "Submitting";
+            case State::Running:
+                return "Running";
+            case State::Completed:
+                return "Completed";
+            case State::Failed:
+                return "Failed";
+            case State::Abandoned:
+                return "Abandoned";
+            case State::Cancelled:
+                return "Cancelled";
+            default:
+                return "Unknown";
+        }
+    }
+
+    // Parse a name produced by ToString(); returns Default when the name does
+    // not match any state.
+    static State FromString(std::string_view Name, State Default = State::Failed)
+    {
+        for (int i = 0; i < static_cast<int>(State::_Count); ++i)
+        {
+            if (Name == ToString(static_cast<State>(i)))
+            {
+                return static_cast<State>(i);
+            }
+        }
+        return Default;
+    }
+
+    // Tick timestamp recorded when each state was entered (0 = never entered).
+    uint64_t Timestamps[static_cast<int>(State::_Count)] = {};
+
+    State ActionState() const { return m_ActionState; }
+    void SetActionState(State NewState);
+
+    bool IsSuccess() const { return ActionState() == State::Completed; }
+    /// Move a Failed/Abandoned action back to Pending for another attempt.
+    bool ResetActionStateToPending();
+    /// True once the action has reached any terminal state.
+    bool IsCompleted() const
+    {
+        return ActionState() == State::Completed || ActionState() == State::Failed || ActionState() == State::Abandoned ||
+               ActionState() == State::Cancelled;
+    }
+
+    void SetResult(CbPackage&& Result);
+    CbPackage& GetResult();
+
+    ComputeServiceSession* GetOwnerSession() const { return m_OwnerSession; }
+
+private:
+    std::atomic<State>     m_ActionState = State::New;
+    ComputeServiceSession* m_OwnerSession = nullptr;  // non-owning back-pointer to the owning session
+    CbPackage              m_Result;                  // valid once IsCompleted() (see GetResult)
+};
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES \ No newline at end of file
diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp
new file mode 100644
index 000000000..e79a6c90f
--- /dev/null
+++ b/src/zencompute/runners/linuxrunner.cpp
@@ -0,0 +1,734 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "linuxrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/except.h>
+# include <zencore/except_fmt.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/timer.h>
+# include <zencore/trace.h>
+
+# include <fcntl.h>
+# include <sched.h>
+# include <signal.h>
+# include <sys/mount.h>
+# include <sys/stat.h>
+# include <sys/syscall.h>
+# include <sys/wait.h>
+# include <unistd.h>
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+namespace {
+
+ // All helper functions in this namespace are async-signal-safe (safe to call
+ // between fork() and execve()). They use only raw syscalls and avoid any
+ // heap allocation, stdio, or other non-AS-safe operations.
+
+ void WriteToFd(int Fd, const char* Buf, size_t Len)
+ {
+ while (Len > 0)
+ {
+ ssize_t Written = write(Fd, Buf, Len);
+ if (Written <= 0)
+ {
+ break;
+ }
+ Buf += Written;
+ Len -= static_cast<size_t>(Written);
+ }
+ }
+
    // Reports a fatal child-side failure to the parent and terminates.
    //
    // Writes "Msg" (optionally followed by ": <errno description>") to the
    // error pipe, then exits with code 127. Called between fork() and execve()
    // when sandbox setup fails; the parent detects failure by reading a
    // non-empty error pipe.
    //
    // NOTE(review): strerror() is not formally async-signal-safe (it may touch
    // locale state); glibc is safe in practice for valid errno values, but
    // confirm this is acceptable for the supported platforms.
    [[noreturn]] void WriteErrorAndExit(int ErrorPipeFd, const char* Msg, int Errno)
    {
        // Write the message prefix. Length is computed by hand to avoid
        // depending on strlen being async-signal-safe.
        size_t MsgLen = 0;
        for (const char* P = Msg; *P; ++P)
        {
            ++MsgLen;
        }
        WriteToFd(ErrorPipeFd, Msg, MsgLen);

        // Append ": " and the errno string if non-zero
        if (Errno != 0)
        {
            WriteToFd(ErrorPipeFd, ": ", 2);
            const char* ErrStr = strerror(Errno);
            size_t ErrLen = 0;
            for (const char* P = ErrStr; *P; ++P)
            {
                ++ErrLen;
            }
            WriteToFd(ErrorPipeFd, ErrStr, ErrLen);
        }

        // 127 conventionally signals "command could not be executed"
        _exit(127);
    }
+
+ int MkdirIfNeeded(const char* Path, mode_t Mode)
+ {
+ if (mkdir(Path, Mode) != 0 && errno != EEXIST)
+ {
+ return -1;
+ }
+ return 0;
+ }
+
+ int BindMountReadOnly(const char* Src, const char* Dst)
+ {
+ if (mount(Src, Dst, nullptr, MS_BIND | MS_REC, nullptr) != 0)
+ {
+ return -1;
+ }
+
+ // Remount read-only
+ if (mount(nullptr, Dst, nullptr, MS_REMOUNT | MS_BIND | MS_RDONLY | MS_REC, nullptr) != 0)
+ {
+ return -1;
+ }
+
+ return 0;
+ }
+
+ // Set up namespace-based sandbox isolation in the child process.
+ // This is called after fork(), before execve(). All operations must be
+ // async-signal-safe.
+ //
+ // The sandbox layout after pivot_root:
+ // / -> the sandbox directory (tmpfs-like, was SandboxPath)
+ // /usr -> bind-mount of host /usr (read-only)
+ // /lib -> bind-mount of host /lib (read-only)
+ // /lib64 -> bind-mount of host /lib64 (read-only, optional)
+ // /etc -> bind-mount of host /etc (read-only)
+ // /worker -> bind-mount of worker directory (read-only)
+ // /proc -> proc filesystem
+ // /dev -> tmpfs with null, zero, urandom
    void SetupNamespaceSandbox(const char* SandboxPath, uid_t Uid, gid_t Gid, const char* WorkerPath, int ErrorPipeFd)
    {
        // 1. Unshare user, mount, and network namespaces.
        //    CLONE_NEWUSER must come first (or together, as here) so the
        //    subsequent mount operations are permitted without privileges;
        //    CLONE_NEWNET leaves the child with no network interfaces.
        if (unshare(CLONE_NEWUSER | CLONE_NEWNS | CLONE_NEWNET) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "unshare() failed", errno);
        }

        // 2. Write UID/GID mappings
        // Must deny setgroups first (required by kernel for unprivileged user namespaces)
        {
            int Fd = open("/proc/self/setgroups", O_WRONLY);
            if (Fd >= 0)
            {
                WriteToFd(Fd, "deny", 4);
                close(Fd);
            }
            // setgroups file may not exist on older kernels; not fatal
        }

        {
            // uid_map: map our UID to 0 inside the namespace
            char Buf[64];
            int Len = snprintf(Buf, sizeof(Buf), "0 %u 1\n", static_cast<unsigned>(Uid));

            int Fd = open("/proc/self/uid_map", O_WRONLY);
            if (Fd < 0)
            {
                WriteErrorAndExit(ErrorPipeFd, "open uid_map failed", errno);
            }
            WriteToFd(Fd, Buf, static_cast<size_t>(Len));
            close(Fd);
        }

        {
            // gid_map: map our GID to 0 inside the namespace
            char Buf[64];
            int Len = snprintf(Buf, sizeof(Buf), "0 %u 1\n", static_cast<unsigned>(Gid));

            int Fd = open("/proc/self/gid_map", O_WRONLY);
            if (Fd < 0)
            {
                WriteErrorAndExit(ErrorPipeFd, "open gid_map failed", errno);
            }
            WriteToFd(Fd, Buf, static_cast<size_t>(Len));
            close(Fd);
        }

        // 3. Privatize the entire mount tree so our mounts don't propagate
        //    back to the host mount namespace.
        if (mount(nullptr, "/", nullptr, MS_REC | MS_PRIVATE, nullptr) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "mount MS_PRIVATE failed", errno);
        }

        // 4. Create mount points inside the sandbox and bind-mount system directories

        // Helper macro-like pattern for building paths inside sandbox.
        // We use stack buffers since we can't allocate heap memory safely.
        // NOTE: BuildPath reuses a single buffer — each call invalidates the
        // previous result, so results are consumed immediately.
        char MountPoint[4096];

        auto BuildPath = [&](const char* Suffix) -> const char* {
            snprintf(MountPoint, sizeof(MountPoint), "%s/%s", SandboxPath, Suffix);
            return MountPoint;
        };

        // /usr (required)
        if (MkdirIfNeeded(BuildPath("usr"), 0755) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/usr failed", errno);
        }
        if (BindMountReadOnly("/usr", BuildPath("usr")) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "bind mount /usr failed", errno);
        }

        // /lib (required)
        if (MkdirIfNeeded(BuildPath("lib"), 0755) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/lib failed", errno);
        }
        if (BindMountReadOnly("/lib", BuildPath("lib")) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "bind mount /lib failed", errno);
        }

        // /lib64 (optional — not all distros have it)
        {
            struct stat St;
            if (stat("/lib64", &St) == 0 && S_ISDIR(St.st_mode))
            {
                if (MkdirIfNeeded(BuildPath("lib64"), 0755) == 0)
                {
                    BindMountReadOnly("/lib64", BuildPath("lib64"));
                    // Failure is non-fatal for lib64
                }
            }
        }

        // /etc (required — for resolv.conf, ld.so.cache, etc.)
        if (MkdirIfNeeded(BuildPath("etc"), 0755) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/etc failed", errno);
        }
        if (BindMountReadOnly("/etc", BuildPath("etc")) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "bind mount /etc failed", errno);
        }

        // /worker — bind-mount worker directory (contains the executable)
        if (MkdirIfNeeded(BuildPath("worker"), 0755) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/worker failed", errno);
        }
        if (BindMountReadOnly(WorkerPath, BuildPath("worker")) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "bind mount worker dir failed", errno);
        }

        // 5. Mount /proc inside sandbox
        if (MkdirIfNeeded(BuildPath("proc"), 0755) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/proc failed", errno);
        }
        if (mount("proc", BuildPath("proc"), "proc", MS_NOSUID | MS_NOEXEC | MS_NODEV, nullptr) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "mount /proc failed", errno);
        }

        // 6. Mount tmpfs /dev and bind-mount essential device nodes
        if (MkdirIfNeeded(BuildPath("dev"), 0755) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/dev failed", errno);
        }
        if (mount("tmpfs", BuildPath("dev"), "tmpfs", MS_NOSUID | MS_NOEXEC, "size=64k,mode=0755") != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "mount tmpfs /dev failed", errno);
        }

        // Bind-mount /dev/null, /dev/zero, /dev/urandom
        // (device nodes can't be mknod'd in an unprivileged user namespace,
        // so they are bind-mounted from the host instead)
        {
            char DevSrc[64];
            char DevDst[4096];

            auto BindDev = [&](const char* Name) {
                snprintf(DevSrc, sizeof(DevSrc), "/dev/%s", Name);
                snprintf(DevDst, sizeof(DevDst), "%s/dev/%s", SandboxPath, Name);

                // Create the file to mount over
                int Fd = open(DevDst, O_WRONLY | O_CREAT, 0666);
                if (Fd >= 0)
                {
                    close(Fd);
                }
                mount(DevSrc, DevDst, nullptr, MS_BIND, nullptr);
                // Non-fatal if individual devices fail
            };

            BindDev("null");
            BindDev("zero");
            BindDev("urandom");
        }

        // 7. pivot_root to sandbox
        // pivot_root requires the new root and put_old to be mount points.
        // Bind-mount sandbox onto itself to make it a mount point.
        if (mount(SandboxPath, SandboxPath, nullptr, MS_BIND | MS_REC, nullptr) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "bind mount sandbox onto itself failed", errno);
        }

        // Create .pivot_old inside sandbox
        char PivotOld[4096];
        snprintf(PivotOld, sizeof(PivotOld), "%s/.pivot_old", SandboxPath);
        if (MkdirIfNeeded(PivotOld, 0755) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "mkdir .pivot_old failed", errno);
        }

        // Raw syscall: glibc provides no pivot_root wrapper
        if (syscall(SYS_pivot_root, SandboxPath, PivotOld) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "pivot_root failed", errno);
        }

        // 8. Now inside new root. Clean up old root.
        if (chdir("/") != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "chdir / failed", errno);
        }

        // MNT_DETACH (lazy unmount) severs all visibility of the host
        // filesystem from inside the sandbox
        if (umount2("/.pivot_old", MNT_DETACH) != 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "umount2 .pivot_old failed", errno);
        }

        // Best-effort: leaving the empty directory behind is harmless
        rmdir("/.pivot_old");
    }
+
+} // anonymous namespace
+
/** Constructs the Linux runner, delegating shared setup (sandbox/worker dirs,
    monitor thread) to LocalProcessRunner.

    @param Resolver             chunk store used to materialize worker/input data
    @param BaseDir              root under which sandboxes and workers are created
    @param Deleter              background deleter for retired sandbox directories
    @param WorkerPool           thread pool used for batched submissions
    @param Sandboxed            when true, children run in namespace isolation
    @param MaxConcurrentActions 0 = derive from core count (see base class)
*/
LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver,
                                       const std::filesystem::path& BaseDir,
                                       DeferredDirectoryDeleter& Deleter,
                                       WorkerThreadPool& WorkerPool,
                                       bool Sandboxed,
                                       int32_t MaxConcurrentActions)
: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions)
, m_Sandboxed(Sandboxed)
{
    // Restore SIGCHLD to default behavior so waitpid() can properly collect
    // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which
    // causes the kernel to auto-reap children, making waitpid() return
    // -1/ECHILD instead of the exit status we need.
    // NOTE(review): this is process-wide — any other component relying on
    // SIG_IGN auto-reaping is affected; confirm that is intended.
    struct sigaction Action = {};
    sigemptyset(&Action.sa_mask);
    Action.sa_handler = SIG_DFL;
    sigaction(SIGCHLD, &Action, nullptr);

    if (m_Sandboxed)
    {
        ZEN_INFO("namespace sandboxing enabled for child processes");
    }
}
+
+SubmitResult
+LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action)
+{
+ ZEN_TRACE_CPU("LinuxProcessRunner::SubmitAction");
+ std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action);
+
+ if (!Prepared)
+ {
+ return SubmitResult{.IsAccepted = false};
+ }
+
+ // Build environment array from worker descriptor
+
+ CbObject WorkerDescription = Prepared->WorkerPackage.GetObject();
+
+ std::vector<std::string> EnvStrings;
+ for (auto& It : WorkerDescription["environment"sv])
+ {
+ EnvStrings.emplace_back(It.AsString());
+ }
+
+ std::vector<char*> Envp;
+ Envp.reserve(EnvStrings.size() + 1);
+ for (auto& Str : EnvStrings)
+ {
+ Envp.push_back(Str.data());
+ }
+ Envp.push_back(nullptr);
+
+ // Build argv: <worker_exe_path> -Build=build.action
+ // Pre-compute all path strings before fork() for async-signal-safety.
+
+ std::string_view ExecPath = WorkerDescription["path"sv].AsString();
+ std::string ExePathStr;
+ std::string SandboxedExePathStr;
+
+ if (m_Sandboxed)
+ {
+ // After pivot_root, the worker dir is at /worker inside the new root
+ std::filesystem::path SandboxedExePath = std::filesystem::path("/worker") / std::filesystem::path(ExecPath);
+ SandboxedExePathStr = SandboxedExePath.string();
+ // We still need the real path for logging
+ ExePathStr = (Prepared->WorkerPath / std::filesystem::path(ExecPath)).string();
+ }
+ else
+ {
+ ExePathStr = (Prepared->WorkerPath / std::filesystem::path(ExecPath)).string();
+ }
+
+ std::string BuildArg = "-Build=build.action";
+
+ // argv[0] should be the path the child will see
+ const std::string& ChildExePath = m_Sandboxed ? SandboxedExePathStr : ExePathStr;
+
+ std::vector<char*> ArgV;
+ ArgV.push_back(const_cast<char*>(ChildExePath.data()));
+ ArgV.push_back(BuildArg.data());
+ ArgV.push_back(nullptr);
+
+ ZEN_DEBUG("Executing: {} {} (sandboxed={})", ExePathStr, BuildArg, m_Sandboxed);
+
+ std::string SandboxPathStr = Prepared->SandboxPath.string();
+ std::string WorkerPathStr = Prepared->WorkerPath.string();
+
+ // Pre-fork: get uid/gid for namespace mapping, create error pipe
+ uid_t CurrentUid = 0;
+ gid_t CurrentGid = 0;
+ int ErrorPipe[2] = {-1, -1};
+
+ if (m_Sandboxed)
+ {
+ CurrentUid = getuid();
+ CurrentGid = getgid();
+
+ if (pipe2(ErrorPipe, O_CLOEXEC) != 0)
+ {
+ throw zen::runtime_error("pipe2() for sandbox error pipe failed: {}", strerror(errno));
+ }
+ }
+
+ pid_t ChildPid = fork();
+
+ if (ChildPid < 0)
+ {
+ int SavedErrno = errno;
+ if (m_Sandboxed)
+ {
+ close(ErrorPipe[0]);
+ close(ErrorPipe[1]);
+ }
+ throw zen::runtime_error("fork() failed: {}", strerror(SavedErrno));
+ }
+
+ if (ChildPid == 0)
+ {
+ // Child process
+
+ if (m_Sandboxed)
+ {
+ // Close read end of error pipe — child only writes
+ close(ErrorPipe[0]);
+
+ SetupNamespaceSandbox(SandboxPathStr.c_str(), CurrentUid, CurrentGid, WorkerPathStr.c_str(), ErrorPipe[1]);
+
+ // After pivot_root, CWD is "/" which is the sandbox root.
+ // execve with the sandboxed path.
+ execve(SandboxedExePathStr.c_str(), ArgV.data(), Envp.data());
+
+ WriteErrorAndExit(ErrorPipe[1], "execve failed", errno);
+ }
+ else
+ {
+ if (chdir(SandboxPathStr.c_str()) != 0)
+ {
+ _exit(127);
+ }
+
+ execve(ExePathStr.c_str(), ArgV.data(), Envp.data());
+ _exit(127);
+ }
+ }
+
+ // Parent process
+
+ if (m_Sandboxed)
+ {
+ // Close write end of error pipe — parent only reads
+ close(ErrorPipe[1]);
+
+ // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC
+ // and read returns 0. If setup failed, child wrote an error message.
+ char ErrBuf[512];
+ ssize_t BytesRead = read(ErrorPipe[0], ErrBuf, sizeof(ErrBuf) - 1);
+ close(ErrorPipe[0]);
+
+ if (BytesRead > 0)
+ {
+ // Sandbox setup or execve failed
+ ErrBuf[BytesRead] = '\0';
+
+ // Reap the child (it called _exit(127))
+ waitpid(ChildPid, nullptr, 0);
+
+ // Clean up the sandbox in the background
+ m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath));
+
+ ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf);
+
+ Action->SetActionState(RunnerAction::State::Failed);
+ return SubmitResult{.IsAccepted = false};
+ }
+ }
+
+ // Store child pid as void* (same convention as zencore/process.cpp)
+
+ Ref<RunningAction> NewAction{new RunningAction()};
+ NewAction->Action = Action;
+ NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid));
+ NewAction->SandboxPath = std::move(Prepared->SandboxPath);
+
+ {
+ RwLock::ExclusiveLockScope _(m_RunningLock);
+ m_RunningMap[Prepared->ActionLsn] = std::move(NewAction);
+ }
+
+ Action->SetActionState(RunnerAction::State::Running);
+
+ return SubmitResult{.IsAccepted = true};
+}
+
+void
+LinuxProcessRunner::SweepRunningActions()
+{
+ ZEN_TRACE_CPU("LinuxProcessRunner::SweepRunningActions");
+ std::vector<Ref<RunningAction>> CompletedActions;
+
+ m_RunningLock.WithExclusiveLock([&] {
+ for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;)
+ {
+ Ref<RunningAction> Running = It->second;
+
+ pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle));
+ int Status = 0;
+
+ pid_t Result = waitpid(Pid, &Status, WNOHANG);
+
+ if (Result == Pid)
+ {
+ if (WIFEXITED(Status))
+ {
+ Running->ExitCode = WEXITSTATUS(Status);
+ }
+ else if (WIFSIGNALED(Status))
+ {
+ Running->ExitCode = 128 + WTERMSIG(Status);
+ }
+ else
+ {
+ Running->ExitCode = 1;
+ }
+
+ Running->ProcessHandle = nullptr;
+
+ CompletedActions.push_back(std::move(Running));
+ It = m_RunningMap.erase(It);
+ }
+ else
+ {
+ ++It;
+ }
+ }
+ });
+
+ ProcessCompletedActions(CompletedActions);
+}
+
/** Terminates every tracked child process and fails its action.

    Two-phase shutdown: first broadcast SIGTERM to all children, then wait up
    to ~2 seconds per child (100ms polls) before escalating to SIGKILL. The
    running map is swapped out under the lock up front so no other thread can
    race on the same entries while we signal and reap. */
void
LinuxProcessRunner::CancelRunningActions()
{
    ZEN_TRACE_CPU("LinuxProcessRunner::CancelRunningActions");
    Stopwatch                                Timer;
    std::unordered_map<int, Ref<RunningAction>> RunningMap;

    // Take ownership of the whole map; from here on only this thread sees it
    m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); });

    if (RunningMap.empty())
    {
        return;
    }

    ZEN_INFO("cancelling all running actions");

    // Send SIGTERM to all running processes first

    for (const auto& [Lsn, Running] : RunningMap)
    {
        pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle));

        if (kill(Pid, SIGTERM) != 0)
        {
            ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno));
        }
    }

    // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up.

    for (auto& [Lsn, Running] : RunningMap)
    {
        pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle));

        // Poll for up to 2 seconds
        bool Exited = false;
        for (int i = 0; i < 20; ++i)
        {
            int   Status     = 0;
            pid_t WaitResult = waitpid(Pid, &Status, WNOHANG);
            if (WaitResult == Pid)
            {
                Exited = true;
                ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn);
                break;
            }
            usleep(100000); // 100ms
        }

        if (!Exited)
        {
            // Escalate: SIGKILL cannot be caught or ignored; the blocking
            // waitpid afterwards is then guaranteed to return promptly
            ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn);
            kill(Pid, SIGKILL);
            waitpid(Pid, nullptr, 0);
        }

        // Sandbox cleanup happens in the background deleter
        m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath));
        Running->Action->SetActionState(RunnerAction::State::Failed);
    }

    ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
+
+bool
+LinuxProcessRunner::CancelAction(int ActionLsn)
+{
+ ZEN_TRACE_CPU("LinuxProcessRunner::CancelAction");
+
+ // Hold the shared lock while sending the signal to prevent the sweep thread
+ // from reaping the PID (via waitpid) between our lookup and kill(). Without
+ // the lock held, the PID could be recycled by the kernel and we'd signal an
+ // unrelated process.
+ bool Sent = false;
+
+ m_RunningLock.WithSharedLock([&] {
+ auto It = m_RunningMap.find(ActionLsn);
+ if (It == m_RunningMap.end())
+ {
+ return;
+ }
+
+ Ref<RunningAction> Target = It->second;
+ if (!Target->ProcessHandle)
+ {
+ return;
+ }
+
+ pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Target->ProcessHandle));
+
+ if (kill(Pid, SIGTERM) != 0)
+ {
+ ZEN_WARN("CancelAction: kill(SIGTERM) for LSN {} (pid {}) failed: {}", ActionLsn, Pid, strerror(errno));
+ return;
+ }
+
+ ZEN_DEBUG("CancelAction: sent SIGTERM to LSN {} (pid {})", ActionLsn, Pid);
+ Sent = true;
+ });
+
+ // The monitor thread will pick up the process exit and mark the action as Failed.
+ return Sent;
+}
+
// Reads cumulative CPU time (utime + stime, in clock ticks) for the given
// process from /proc/<pid>/stat. Returns 0 when the process is gone, the stat
// file cannot be read, or the line cannot be parsed.
//
// Fixes vs. previous revision:
// - Buffer grown 256 -> 512 bytes: a maximum-length comm name plus large
//   fault/time counters could push utime/stime past 256 bytes, silently
//   truncating them to garbage.
// - sscanf return value is now checked so a malformed line yields 0 instead
//   of whatever happened to be parsed.
static uint64_t
ReadProcStatCpuTicks(pid_t Pid)
{
    char Path[64];
    snprintf(Path, sizeof(Path), "/proc/%d/stat", static_cast<int>(Pid));

    char Buf[512];
    int  Fd = open(Path, O_RDONLY);
    if (Fd < 0)
    {
        return 0;
    }

    ssize_t Len = read(Fd, Buf, sizeof(Buf) - 1);
    close(Fd);

    if (Len <= 0)
    {
        return 0;
    }

    Buf[Len] = '\0';

    // Skip past "pid (name) " — find the LAST ')' since the comm field may
    // itself contain spaces or parentheses
    const char* P = strrchr(Buf, ')');
    if (!P)
    {
        return 0;
    }

    P += 2; // skip ') '

    // Remaining fields (space-separated, 0-indexed from here):
    //  0:state 1:ppid 2:pgrp 3:session 4:tty_nr 5:tty_pgrp 6:flags
    //  7:minflt 8:cminflt 9:majflt 10:cmajflt 11:utime 12:stime
    unsigned long UTime = 0;
    unsigned long STime = 0;
    if (sscanf(P, "%*c %*d %*d %*d %*d %*d %*u %*u %*u %*u %*u %lu %lu", &UTime, &STime) != 2)
    {
        return 0;
    }
    return UTime + STime;
}
+
/** Samples CPU consumption for one running child process.

    Updates two metrics on the action:
    - CpuSeconds: absolute cumulative CPU time (utime+stime / clock ticks),
      available from the first successful sample.
    - CpuUsagePercent: delta-based utilization over the interval since the
      previous sample; requires two valid samples before it is set.

    NOTE(review): sysconf(_SC_CLK_TCK) returning -1 would make the divisions
    below produce nonsense — in practice it never fails on Linux, but a guard
    may be worth adding. */
void
LinuxProcessRunner::SampleProcessCpu(RunningAction& Running)
{
    // Ticks-per-second is constant for the process lifetime; cache it once
    static const long ClkTck = sysconf(_SC_CLK_TCK);

    const pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running.ProcessHandle));

    const uint64_t NowTicks       = GetHifreqTimerValue();
    const uint64_t CurrentOsTicks = ReadProcStatCpuTicks(Pid);

    if (CurrentOsTicks == 0)
    {
        // Process gone or /proc entry unreadable — record timestamp without updating usage
        Running.LastCpuSampleTicks = NowTicks;
        Running.LastCpuOsTicks     = 0;
        return;
    }

    // Cumulative CPU seconds (absolute, available from first sample)
    Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / ClkTck), std::memory_order_relaxed);

    // Utilization needs a previous sample to diff against
    if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0)
    {
        const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks);
        if (ElapsedMs > 0)
        {
            // (delta CPU time / wall time) * 100 — can exceed 100 for
            // multi-threaded children
            const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks;
            const float    CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) * 1000.0 / ClkTck / ElapsedMs * 100.0);
            Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed);
        }
    }

    Running.LastCpuSampleTicks = NowTicks;
    Running.LastCpuOsTicks     = CurrentOsTicks;
}
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/linuxrunner.h b/src/zencompute/runners/linuxrunner.h
new file mode 100644
index 000000000..266de366b
--- /dev/null
+++ b/src/zencompute/runners/linuxrunner.h
@@ -0,0 +1,44 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "localrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX
+
+namespace zen::compute {
+
+/** Native Linux process runner for executing Linux worker executables directly.
+
+ Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting,
+ input/output handling, and monitor thread infrastructure. Overrides only the
+ platform-specific methods: process spawning, sweep, and cancellation.
+
+ When Sandboxed is true, child processes are isolated using Linux namespaces:
+ user, mount, and network namespaces are unshared so the child has no network
+ access and can only see the sandbox directory (with system libraries bind-mounted
+ read-only). This requires no special privileges thanks to user namespaces.
+ */
class LinuxProcessRunner : public LocalProcessRunner
{
public:
    /// @param Sandboxed            enable namespace isolation for children
    /// @param MaxConcurrentActions 0 = derive from logical core count
    LinuxProcessRunner(ChunkResolver&               Resolver,
                       const std::filesystem::path& BaseDir,
                       DeferredDirectoryDeleter&    Deleter,
                       WorkerThreadPool&            WorkerPool,
                       bool                         Sandboxed = false,
                       int32_t                      MaxConcurrentActions = 0);

    /// Fork/exec a worker process for the action (namespace-sandboxed if enabled).
    [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
    /// Reap exited children (non-blocking waitpid) and retire their actions.
    void SweepRunningActions() override;
    /// SIGTERM all children, escalating to SIGKILL after a grace period.
    void CancelRunningActions() override;
    /// SIGTERM a single action's process; sweep later observes the exit.
    bool CancelAction(int ActionLsn) override;
    /// Update CpuSeconds/CpuUsagePercent from /proc/<pid>/stat.
    void SampleProcessCpu(RunningAction& Running) override;

private:
    // Whether children run inside user/mount/net namespaces (set once in ctor)
    bool m_Sandboxed = false;
};
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp
new file mode 100644
index 000000000..7aaefb06e
--- /dev/null
+++ b/src/zencompute/runners/localrunner.cpp
@@ -0,0 +1,674 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "localrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/compress.h>
+# include <zencore/except_fmt.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/iobuffer.h>
+# include <zencore/iohash.h>
+# include <zencore/system.h>
+# include <zencore/scopeguard.h>
+# include <zencore/timer.h>
+# include <zencore/trace.h>
+# include <zenstore/cidstore.h>
+
+# include <span>
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver,
+ const std::filesystem::path& BaseDir,
+ DeferredDirectoryDeleter& Deleter,
+ WorkerThreadPool& WorkerPool,
+ int32_t MaxConcurrentActions)
+: FunctionRunner(BaseDir)
+, m_Log(logging::Get("local_exec"))
+, m_ChunkResolver(Resolver)
+, m_WorkerPath(std::filesystem::weakly_canonical(BaseDir / "workers"))
+, m_SandboxPath(std::filesystem::weakly_canonical(BaseDir / "scratch"))
+, m_DeferredDeleter(Deleter)
+, m_WorkerPool(WorkerPool)
+{
+ SystemMetrics Sm = GetSystemMetricsForReporting();
+
+ m_MaxRunningActions = Sm.LogicalProcessorCount * 2;
+
+ if (MaxConcurrentActions > 0)
+ {
+ m_MaxRunningActions = MaxConcurrentActions;
+ }
+
+ ZEN_INFO("Max concurrent action count: {}", m_MaxRunningActions);
+
+ bool DidCleanup = false;
+
+ if (std::filesystem::is_directory(m_ActionsPath))
+ {
+ ZEN_INFO("Cleaning '{}'", m_ActionsPath);
+
+ std::error_code Ec;
+ CleanDirectory(m_ActionsPath, /* ForceRemoveReadOnlyFiles */ true, Ec);
+
+ if (Ec)
+ {
+ ZEN_WARN("Unable to clean '{}': {}", m_ActionsPath, Ec.message());
+ }
+
+ DidCleanup = true;
+ }
+
+ if (std::filesystem::is_directory(m_SandboxPath))
+ {
+ ZEN_INFO("Cleaning '{}'", m_SandboxPath);
+ std::error_code Ec;
+ CleanDirectory(m_SandboxPath, /* ForceRemoveReadOnlyFiles */ true, Ec);
+
+ if (Ec)
+ {
+ ZEN_WARN("Unable to clean '{}': {}", m_SandboxPath, Ec.message());
+ }
+
+ DidCleanup = true;
+ }
+
+ // We clean out all workers on startup since we can't know they are good. They could be bad
+ // due to tampering, malware (which I also mean to include AV and antimalware software) or
+ // other processes we have no control over
+ if (std::filesystem::is_directory(m_WorkerPath))
+ {
+ ZEN_INFO("Cleaning '{}'", m_WorkerPath);
+ std::error_code Ec;
+ CleanDirectory(m_WorkerPath, /* ForceRemoveReadOnlyFiles */ true, Ec);
+
+ if (Ec)
+ {
+ ZEN_WARN("Unable to clean '{}': {}", m_WorkerPath, Ec.message());
+ }
+
+ DidCleanup = true;
+ }
+
+ if (DidCleanup)
+ {
+ ZEN_INFO("Cleanup complete");
+ }
+
+ m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this};
+
+# if ZEN_PLATFORM_WINDOWS
+ // Suppress any error dialogs caused by missing dependencies
+ UINT OldMode = ::SetErrorMode(0);
+ ::SetErrorMode(OldMode | SEM_FAILCRITICALERRORS);
+# endif
+
+ m_AcceptNewActions = true;
+}
+
+LocalProcessRunner::~LocalProcessRunner()
+{
+ try
+ {
+ Shutdown();
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_WARN("exception during local process runner shutdown: {}", Ex.what());
+ }
+}
+
+void
+LocalProcessRunner::Shutdown()
+{
+ ZEN_TRACE_CPU("LocalProcessRunner::Shutdown");
+ m_AcceptNewActions = false;
+
+ m_MonitorThreadEnabled = false;
+ m_MonitorThreadEvent.Set();
+ if (m_MonitorThread.joinable())
+ {
+ m_MonitorThread.join();
+ }
+
+ CancelRunningActions();
+}
+
+std::filesystem::path
+LocalProcessRunner::CreateNewSandbox()
+{
+ ZEN_TRACE_CPU("LocalProcessRunner::CreateNewSandbox");
+ std::string UniqueId = std::to_string(++m_SandboxCounter);
+ std::filesystem::path Path = m_SandboxPath / UniqueId;
+ zen::CreateDirectories(Path);
+
+ return Path;
+}
+
/** Registers a worker package with the runner.

    Currently this only has an effect when action dumping (m_DumpActions) is
    enabled: the worker descriptor, its manifested file tree, and all
    referenced chunks are written under the actions directory for offline
    inspection. The manifest cache itself is populated lazily on first
    submission (see ManifestWorker). */
void
LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage)
{
    ZEN_TRACE_CPU("LocalProcessRunner::RegisterWorker");
    if (m_DumpActions)
    {
        CbObject      WorkerDescriptor = WorkerPackage.GetObject();
        const IoHash& WorkerId         = WorkerPackage.GetObjectHash();

        std::string           UniqueId = fmt::format("worker_{}"sv, WorkerId);
        std::filesystem::path Path     = m_ActionsPath / UniqueId;

        // Raw descriptor for inspection with compact-binary tooling
        zen::WriteFile(Path / "worker.ucb", WorkerDescriptor.GetBuffer().AsIoBuffer());

        // Manifest the worker tree and dump every referenced chunk alongside it
        ManifestWorker(WorkerPackage, Path / "tree", [&](const IoHash& Cid, CompressedBuffer& ChunkBuffer) {
            std::filesystem::path ChunkPath = Path / "chunks" / Cid.ToHexString();
            zen::WriteFile(ChunkPath, ChunkBuffer.GetCompressed());
        });

        ZEN_INFO("dumped worker '{}' to 'file://{}'", WorkerId, Path);
    }
}
+
+size_t
+LocalProcessRunner::QueryCapacity()
+{
+ // Estimate how much more work we're ready to accept
+
+ RwLock::SharedLockScope _{m_RunningLock};
+
+ if (!m_AcceptNewActions)
+ {
+ return 0;
+ }
+
+ const size_t InFlightCount = m_RunningMap.size() + m_SubmittingCount.load(std::memory_order_relaxed);
+
+ if (const size_t MaxRunningActions = m_MaxRunningActions; InFlightCount >= MaxRunningActions)
+ {
+ return 0;
+ }
+ else
+ {
+ return MaxRunningActions - InFlightCount;
+ }
+}
+
/** Submits a batch of actions, returning one SubmitResult per input (same order).

    Batches of zero or one action are submitted synchronously. Larger batches
    are admitted against the current capacity estimate and then dispatched
    fire-and-forget to the worker pool, so the calling scheduler thread
    returns immediately. Note the capacity check is advisory — another thread
    may submit between QueryCapacity() and the pool task running; the
    per-action submission path re-checks limits. */
std::vector<SubmitResult>
LocalProcessRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
{
    if (Actions.size() <= 1)
    {
        std::vector<SubmitResult> Results;

        for (const Ref<RunnerAction>& Action : Actions)
        {
            Results.push_back(SubmitAction(Action));
        }

        return Results;
    }

    // For nontrivial batches, check capacity upfront and accept what fits.
    // Accepted actions are transitioned to Submitting and dispatched to the
    // worker pool as fire-and-forget, so SubmitActions returns immediately
    // and the scheduler thread is free to handle completions and updates.

    size_t Available = QueryCapacity();

    std::vector<SubmitResult> Results(Actions.size());

    size_t AcceptCount = std::min(Available, Actions.size());

    for (size_t i = 0; i < AcceptCount; ++i)
    {
        const Ref<RunnerAction>& Action = Actions[i];

        // m_SubmittingCount keeps QueryCapacity honest while the pool task is
        // queued/running; the guard below releases it when submission ends
        Action->SetActionState(RunnerAction::State::Submitting);
        m_SubmittingCount.fetch_add(1, std::memory_order_relaxed);

        Results[i] = SubmitResult{.IsAccepted = true};

        // Capture Action by value: the Ref must outlive this scope
        m_WorkerPool.ScheduleWork(
            [this, Action]() {
                auto CountGuard = MakeGuard([this] { m_SubmittingCount.fetch_sub(1, std::memory_order_relaxed); });

                SubmitResult Result = SubmitAction(Action);

                if (!Result.IsAccepted)
                {
                    // This might require another state? We should
                    // distinguish between outright rejections (e.g. invalid action)
                    // and transient failures (e.g. failed to launch process) which might
                    // be retried by the scheduler, but for now just fail the action
                    Action->SetActionState(RunnerAction::State::Failed);
                }
            },
            WorkerThreadPool::EMode::EnableBacklog);
    }

    // Everything beyond the accepted window is rejected up front
    for (size_t i = AcceptCount; i < Actions.size(); ++i)
    {
        Results[i] = SubmitResult{.IsAccepted = false};
    }

    return Results;
}
+
/** Performs the platform-independent part of submitting an action: capacity
    check, sandbox creation, worker manifesting, and writing the action
    descriptor plus all input chunks into the sandbox.

    Returns std::nullopt when the runner is at capacity or shutting down.
    Throws if an input chunk is missing; the sandbox directory is cleaned up
    via a scope guard in that case. On success the guard is dismissed and
    ownership of the sandbox path transfers to the caller. */
std::optional<LocalProcessRunner::PreparedAction>
LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action)
{
    ZEN_TRACE_CPU("LocalProcessRunner::PrepareActionSubmission");

    // Verify whether we can accept more work

    {
        RwLock::SharedLockScope _{m_RunningLock};

        if (!m_AcceptNewActions)
        {
            return std::nullopt;
        }

        if (m_RunningMap.size() >= size_t(m_MaxRunningActions))
        {
            return std::nullopt;
        }
    }

    // Each enqueued action is assigned an integer index (logical sequence number),
    // which we use as a key for tracking data structures and as an opaque id which
    // may be used by clients to reference the scheduled action

    const int32_t   ActionLsn = Action->ActionLsn;
    const CbObject& ActionObj = Action->ActionObj;

    MaybeDumpAction(ActionLsn, ActionObj);

    std::filesystem::path SandboxPath = CreateNewSandbox();

    // Ensure the sandbox directory is cleaned up if any subsequent step throws
    auto SandboxGuard = MakeGuard([&] { m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(SandboxPath)); });

    CbPackage WorkerPackage = Action->Worker.Descriptor;

    // Materialize (or reuse) the worker's file tree on disk
    std::filesystem::path WorkerPath = ManifestWorker(Action->Worker);

    // Write out action

    zen::WriteFile(SandboxPath / "build.action", ActionObj.GetBuffer().AsIoBuffer());

    // Manifest inputs in sandbox

    ActionObj.IterateAttachments([&](CbFieldView Field) {
        const IoHash          Cid = Field.AsHash();
        std::filesystem::path FilePath{SandboxPath / "Inputs"sv / Cid.ToHexString()};
        IoBuffer              DataBuffer = m_ChunkResolver.FindChunkByCid(Cid);

        if (!DataBuffer)
        {
            // Propagates out of PrepareActionSubmission; SandboxGuard cleans up
            throw std::runtime_error(fmt::format("input CID chunk '{}' missing", Cid));
        }

        zen::WriteFile(FilePath, DataBuffer);
    });

    Action->ExecutionLocation = "local";

    // Success: the caller now owns the sandbox directory
    SandboxGuard.Dismiss();

    return PreparedAction{
        .ActionLsn     = ActionLsn,
        .SandboxPath   = std::move(SandboxPath),
        .WorkerPath    = std::move(WorkerPath),
        .WorkerPackage = std::move(WorkerPackage),
    };
}
+
/** Base-class fallback: always rejects. LocalProcessRunner cannot launch
    processes itself — platform subclasses (e.g. LinuxProcessRunner) override
    this with the actual spawn logic. */
SubmitResult
LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action)
{
    // Base class is not directly usable — platform subclasses override this
    ZEN_UNUSED(Action);
    return SubmitResult{.IsAccepted = false};
}
+
+size_t
+LocalProcessRunner::GetSubmittedActionCount()
+{
+ RwLock::SharedLockScope _(m_RunningLock);
+ return m_RunningMap.size();
+}
+
/** Returns the on-disk directory containing the worker's manifested file
    tree, materializing it on first use.

    Uses a double-checked pattern: a cheap existence probe under the shared
    lock, then — only on a miss — the shared lock is released, the exclusive
    lock taken, and existence re-checked before manifesting (another thread
    may have won the race in between). */
std::filesystem::path
LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker)
{
    ZEN_TRACE_CPU("LocalProcessRunner::ManifestWorker");
    RwLock::SharedLockScope _(m_WorkerLock);

    std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId);

    if (!std::filesystem::exists(WorkerDir))
    {
        // Upgrade: RwLock cannot atomically promote shared -> exclusive,
        // hence release + reacquire + re-check
        _.ReleaseNow();

        RwLock::ExclusiveLockScope $(m_WorkerLock);

        if (!std::filesystem::exists(WorkerDir))
        {
            // No chunk callback needed here — chunks go straight to disk
            ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {});
        }
    }

    return WorkerDir;
}
+
+void
+LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromPackage,
+ CbObjectView FileEntry,
+ const std::filesystem::path& SandboxRootPath,
+ std::function<void(const IoHash&, CompressedBuffer&)>& ChunkReferenceCallback)
+{
+ std::string_view Name = FileEntry["name"sv].AsString();
+ const IoHash ChunkHash = FileEntry["hash"sv].AsHash();
+ const uint64_t Size = FileEntry["size"sv].AsUInt64();
+
+ CompressedBuffer Compressed;
+
+ if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash))
+ {
+ Compressed = Attachment->AsCompressedBinary();
+ }
+ else
+ {
+ IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(ChunkHash);
+
+ if (!DataBuffer)
+ {
+ throw std::runtime_error(fmt::format("worker chunk '{}' missing", ChunkHash));
+ }
+
+ uint64_t DataRawSize = 0;
+ IoHash DataRawHash;
+ Compressed = CompressedBuffer::FromCompressed(SharedBuffer{DataBuffer}, DataRawHash, DataRawSize);
+
+ if (DataRawSize != Size)
+ {
+ throw std::runtime_error(
+ fmt::format("worker chunk '{}' size: {}, action spec expected {}", ChunkHash, DataBuffer.Size(), Size));
+ }
+ }
+
+ ChunkReferenceCallback(ChunkHash, Compressed);
+
+ std::filesystem::path FilePath{SandboxRootPath / std::filesystem::path(Name).make_preferred()};
+
+ // Validate the resolved path stays within the sandbox to prevent directory traversal
+ // via malicious names like "../../etc/evil"
+ //
+ // This might be worth revisiting to frontload the validation and eliminate some memory
+ // allocations in the future.
+ {
+ std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxRootPath);
+ std::filesystem::path CanonicalFile = std::filesystem::weakly_canonical(FilePath);
+ std::string RootStr = CanonicalRoot.string();
+ std::string FileStr = CanonicalFile.string();
+
+ if (FileStr.size() < RootStr.size() || FileStr.compare(0, RootStr.size(), RootStr) != 0)
+ {
+ throw zen::runtime_error("path traversal detected: '{}' escapes sandbox root '{}'", Name, SandboxRootPath);
+ }
+ }
+
+ SharedBuffer Decompressed = Compressed.Decompress();
+ zen::WriteFile(FilePath, Decompressed.AsIoBuffer());
+}
+
+void
+LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage,
+ const std::filesystem::path& SandboxPath,
+ std::function<void(const IoHash&, CompressedBuffer&)>&& ChunkReferenceCallback)
+{
+ CbObject WorkerDescription = WorkerPackage.GetObject();
+
+ // Manifest worker in Sandbox
+
+ for (auto& It : WorkerDescription["executables"sv])
+ {
+ DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback);
+# if !ZEN_PLATFORM_WINDOWS
+ std::string_view ExeName = It.AsObjectView()["name"sv].AsString();
+ std::filesystem::path ExePath{SandboxPath / std::filesystem::path(ExeName).make_preferred()};
+ std::filesystem::permissions(
+ ExePath,
+ std::filesystem::perms::owner_exec | std::filesystem::perms::group_exec | std::filesystem::perms::others_exec,
+ std::filesystem::perm_options::add);
+# endif
+ }
+
+ for (auto& It : WorkerDescription["dirs"sv])
+ {
+ std::string_view Name = It.AsString();
+ std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()};
+
+ // Validate dir path stays within sandbox
+ {
+ std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxPath);
+ std::filesystem::path CanonicalDir = std::filesystem::weakly_canonical(DirPath);
+ std::string RootStr = CanonicalRoot.string();
+ std::string DirStr = CanonicalDir.string();
+
+ if (DirStr.size() < RootStr.size() || DirStr.compare(0, RootStr.size(), RootStr) != 0)
+ {
+ throw zen::runtime_error("path traversal detected: dir '{}' escapes sandbox root '{}'", Name, SandboxPath);
+ }
+ }
+
+ zen::CreateDirectories(DirPath);
+ }
+
+ for (auto& It : WorkerDescription["files"sv])
+ {
+ DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback);
+ }
+
+ WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer());
+}
+
// Collect the results of a finished action from its sandbox into a CbPackage.
//
// Reads "build.output" (a compact-binary object the worker wrote), then for
// every attachment hash it references loads the corresponding file from the
// sandbox's Outputs/ directory and adds it as a package attachment.
//
// Throws std::system_error when an output file cannot be read and
// std::runtime_error when an output is not valid CompressedBuffer data.
CbPackage
LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath)
{
	ZEN_TRACE_CPU("LocalProcessRunner::GatherActionOutputs");
	std::filesystem::path OutputFile = SandboxPath / "build.output";
	FileContents          OutputData = zen::ReadFile(OutputFile);

	if (OutputData.ErrorCode)
	{
		throw std::system_error(OutputData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputFile));
	}

	CbPackage OutputPackage;
	CbObject  Output = zen::LoadCompactBinaryObject(OutputData.Flatten());

	// Tallies for the debug log line below.
	uint64_t TotalAttachmentBytes    = 0;
	uint64_t TotalRawAttachmentBytes = 0;

	Output.IterateAttachments([&](CbFieldView Field) {
		// Output files are stored hash-addressed under Outputs/.
		IoHash                Hash = Field.AsHash();
		std::filesystem::path OutputPath{SandboxPath / "Outputs" / Hash.ToHexString()};
		FileContents          ChunkData = zen::ReadFile(OutputPath);

		if (ChunkData.ErrorCode)
		{
			throw std::system_error(ChunkData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputPath));
		}

		// Output files are expected to already be in CompressedBuffer format;
		// FromCompressed also recovers the raw hash/size for validation.
		uint64_t         ChunkDataRawSize = 0;
		IoHash           ChunkDataHash;
		CompressedBuffer AttachmentBuffer =
		    CompressedBuffer::FromCompressed(SharedBuffer(ChunkData.Flatten()), ChunkDataHash, ChunkDataRawSize);

		if (!AttachmentBuffer)
		{
			throw std::runtime_error("Invalid output encountered (not valid CompressedBuffer format)");
		}

		TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize();
		TotalRawAttachmentBytes += ChunkDataRawSize;

		CbAttachment Attachment(std::move(AttachmentBuffer), ChunkDataHash);
		OutputPackage.AddAttachment(Attachment);
	});

	OutputPackage.SetObject(Output);

	ZEN_DEBUG("Action completed with {} attachments ({} compressed, {} uncompressed)",
	          OutputPackage.GetAttachments().size(),
	          NiceBytes(TotalAttachmentBytes),
	          NiceBytes(TotalRawAttachmentBytes));

	return OutputPackage;
}
+
+void
+LocalProcessRunner::MonitorThreadFunction()
+{
+ SetCurrentThreadName("LocalProcessRunner_Monitor");
+
+ auto _ = MakeGuard([&] { ZEN_INFO("monitor thread exiting"); });
+
+ do
+ {
+ // On Windows it's possible to wait on process handles, so we wait for either a process to exit
+ // or for the monitor event to be signaled (which indicates we should check for cancellation
+ // or shutdown). This could be further improved by using a completion port and registering process
+ // handles with it, but this is a reasonable first implementation given that we shouldn't be dealing
+ // with an enormous number of concurrent processes.
+ //
+ // On other platforms we just wait on the monitor event and poll for process exits at intervals.
+# if ZEN_PLATFORM_WINDOWS
+ auto WaitOnce = [&] {
+ HANDLE WaitHandles[MAXIMUM_WAIT_OBJECTS];
+
+ uint32_t NumHandles = 0;
+
+ WaitHandles[NumHandles++] = m_MonitorThreadEvent.GetWindowsHandle();
+
+ m_RunningLock.WithSharedLock([&] {
+ for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd && NumHandles < MAXIMUM_WAIT_OBJECTS; ++It)
+ {
+ Ref<RunningAction> Action = It->second;
+
+ WaitHandles[NumHandles++] = Action->ProcessHandle;
+ }
+ });
+
+ DWORD WaitResult = WaitForMultipleObjects(NumHandles, WaitHandles, FALSE, 1000);
+
+ // return true if a handle was signaled
+ return (WaitResult <= NumHandles);
+ };
+# else
+ auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(1000); };
+# endif
+
+ while (!WaitOnce())
+ {
+ if (m_MonitorThreadEnabled == false)
+ {
+ return;
+ }
+
+ SweepRunningActions();
+ SampleRunningProcessCpu();
+ }
+
+ // Signal received
+
+ SweepRunningActions();
+ SampleRunningProcessCpu();
+ } while (m_MonitorThreadEnabled);
+}
+
// Cancel every currently running action. The base implementation is a no-op;
// platform subclasses (which actually spawn processes) override this.
void
LocalProcessRunner::CancelRunningActions()
{
	// Base class is not directly usable — platform subclasses override this
}
+
+void
+LocalProcessRunner::SampleRunningProcessCpu()
+{
+ static constexpr uint64_t kSampleIntervalMs = 5'000;
+
+ m_RunningLock.WithSharedLock([&] {
+ const uint64_t Now = GetHifreqTimerValue();
+ for (auto& [Lsn, Running] : m_RunningMap)
+ {
+ const bool NeverSampled = Running->LastCpuSampleTicks == 0;
+ const bool IntervalElapsed = Stopwatch::GetElapsedTimeMs(Now - Running->LastCpuSampleTicks) >= kSampleIntervalMs;
+ if (NeverSampled || IntervalElapsed)
+ {
+ SampleProcessCpu(*Running);
+ }
+ }
+ });
+}
+
// Reap any child processes that have exited. The base implementation only
// emits a trace event; platform subclasses override this with real
// wait/collect logic and then call ProcessCompletedActions().
void
LocalProcessRunner::SweepRunningActions()
{
	ZEN_TRACE_CPU("LocalProcessRunner::SweepRunningActions");
}
+
+void
+LocalProcessRunner::ProcessCompletedActions(std::vector<Ref<RunningAction>>& CompletedActions)
+{
+ ZEN_TRACE_CPU("LocalProcessRunner::ProcessCompletedActions");
+ // Shared post-processing: gather outputs, set state, clean sandbox.
+ // Note that this must be called without holding any local locks
+ // otherwise we may end up with deadlocks.
+
+ for (Ref<RunningAction> Running : CompletedActions)
+ {
+ const int ActionLsn = Running->Action->ActionLsn;
+
+ if (Running->ExitCode == 0)
+ {
+ try
+ {
+ // Gather outputs
+
+ CbPackage OutputPackage = GatherActionOutputs(Running->SandboxPath);
+
+ Running->Action->SetResult(std::move(OutputPackage));
+ Running->Action->SetActionState(RunnerAction::State::Completed);
+
+ // Enqueue sandbox for deferred background deletion, giving
+ // file handles time to close before we attempt removal.
+ m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath));
+
+ // Success -- continue with next iteration of the loop
+ continue;
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what());
+ }
+ }
+
+ // Failed - clean up the sandbox in the background.
+
+ m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath));
+ Running->Action->SetActionState(RunnerAction::State::Failed);
+ }
+}
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h
new file mode 100644
index 000000000..7493e980b
--- /dev/null
+++ b/src/zencompute/runners/localrunner.h
@@ -0,0 +1,138 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencompute/computeservice.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include "functionrunner.h"
+
+# include <zencore/thread.h>
+# include <zencore/zencore.h>
+# include <zenstore/cidstore.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/logging.h>
+
+# include "deferreddeleter.h"
+
+# include <zencore/workthreadpool.h>
+
+# include <atomic>
+# include <filesystem>
+# include <optional>
+# include <thread>
+
+namespace zen {
+class CbPackage;
+}
+
+namespace zen::compute {
+
+/** Direct process spawner
+
+ This runner simply sets up a directory structure for each job and
+ creates a process to perform the computation in it. It is not very
+ efficient and is intended mostly for testing.
+
+ */
+
class LocalProcessRunner : public FunctionRunner
{
    // Deleting the move operations also suppresses the implicit copies, making
    // the runner neither movable nor copyable.
    LocalProcessRunner(LocalProcessRunner&&) = delete;
    LocalProcessRunner& operator=(LocalProcessRunner&&) = delete;

public:
    LocalProcessRunner(ChunkResolver& Resolver,
                       const std::filesystem::path& BaseDir,
                       DeferredDirectoryDeleter& Deleter,
                       WorkerThreadPool& WorkerPool,
                       int32_t MaxConcurrentActions = 0);
    ~LocalProcessRunner();

    virtual void Shutdown() override;
    virtual void RegisterWorker(const CbPackage& WorkerPackage) override;
    // The base implementation always rejects; platform subclasses override
    // this with real process spawning.
    [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
    [[nodiscard]] virtual bool IsHealthy() override { return true; }
    [[nodiscard]] virtual size_t GetSubmittedActionCount() override;
    [[nodiscard]] virtual size_t QueryCapacity() override;
    [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override;

protected:
    LoggerRef Log() { return m_Log; }

    LoggerRef m_Log;

    // Book-keeping for one spawned worker process.
    struct RunningAction : public RefCounted
    {
        Ref<RunnerAction> Action;
        // Platform-specific process identity; the Mac runner stores the pid
        // cast to void*, the Windows runner a process HANDLE.
        void* ProcessHandle = nullptr;
        int ExitCode = 0;
        std::filesystem::path SandboxPath;

        // State for periodic CPU usage sampling
        uint64_t LastCpuSampleTicks = 0; // hifreq timer value at last sample
        uint64_t LastCpuOsTicks = 0;     // OS CPU ticks (platform-specific units) at last sample
    };

    std::atomic_bool m_AcceptNewActions;
    ChunkResolver& m_ChunkResolver;
    RwLock m_WorkerLock; // guards worker manifesting under m_WorkerPath
    std::filesystem::path m_WorkerPath;
    std::atomic<int32_t> m_SandboxCounter = 0; // generates unique sandbox names
    std::filesystem::path m_SandboxPath;
    int32_t m_MaxRunningActions = 64; // arbitrary limit for testing

    // if used in conjunction with m_ResultsLock, this lock must be taken *after*
    // m_ResultsLock to avoid deadlocks
    RwLock m_RunningLock;
    std::unordered_map<int, Ref<RunningAction>> m_RunningMap;

    std::atomic<int32_t> m_SubmittingCount = 0;
    DeferredDirectoryDeleter& m_DeferredDeleter;
    WorkerThreadPool& m_WorkerPool;

    // Monitor thread: reaps finished processes and samples CPU usage until
    // m_MonitorThreadEnabled is cleared and m_MonitorThreadEvent is set.
    std::thread m_MonitorThread;
    std::atomic<bool> m_MonitorThreadEnabled{true};
    Event m_MonitorThreadEvent;
    void MonitorThreadFunction();
    virtual void SweepRunningActions();
    virtual void CancelRunningActions();

    // Sample CPU usage for all currently running processes (throttled per-action).
    void SampleRunningProcessCpu();

    // Override in platform runners to sample one process. Called under a shared RunningLock.
    virtual void SampleProcessCpu(RunningAction& /*Running*/) {}

    // Shared preamble for SubmitAction: capacity check, sandbox creation,
    // worker manifesting, action writing, input manifesting.
    struct PreparedAction
    {
        int32_t ActionLsn;
        std::filesystem::path SandboxPath;
        std::filesystem::path WorkerPath;
        CbPackage WorkerPackage;
    };
    // Returns nullopt when the action cannot be accepted (e.g. at capacity).
    std::optional<PreparedAction> PrepareActionSubmission(Ref<RunnerAction> Action);

    // Shared post-processing for SweepRunningActions: gather outputs,
    // set state, clean sandbox.
    void ProcessCompletedActions(std::vector<Ref<RunningAction>>& CompletedActions);

    std::filesystem::path CreateNewSandbox();
    // Unpack a worker package into SandboxPath, invoking the callback for
    // every chunk referenced along the way.
    void ManifestWorker(const CbPackage& WorkerPackage,
                        const std::filesystem::path& SandboxPath,
                        std::function<void(const IoHash&, CompressedBuffer&)>&& ChunkReferenceCallback);
    // Cached variant: manifests into the shared worker directory on first use.
    std::filesystem::path ManifestWorker(const WorkerDesc& Worker);
    CbPackage GatherActionOutputs(std::filesystem::path SandboxPath);

    // Materialize one package file entry into the sandbox, validating that the
    // resolved path stays inside SandboxRootPath.
    void DecompressAttachmentToFile(const CbPackage& FromPackage,
                                    CbObjectView FileEntry,
                                    const std::filesystem::path& SandboxRootPath,
                                    std::function<void(const IoHash&, CompressedBuffer&)>& ChunkReferenceCallback);
};
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp
new file mode 100644
index 000000000..5cec90699
--- /dev/null
+++ b/src/zencompute/runners/macrunner.cpp
@@ -0,0 +1,491 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "macrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_MAC
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/except.h>
+# include <zencore/except_fmt.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/timer.h>
+# include <zencore/trace.h>
+
+# include <fcntl.h>
+# include <libproc.h>
+# include <sandbox.h>
+# include <signal.h>
+# include <sys/wait.h>
+# include <unistd.h>
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+namespace {
+
+ // All helper functions in this namespace are async-signal-safe (safe to call
+ // between fork() and execve()). They use only raw syscalls and avoid any
+ // heap allocation, stdio, or other non-AS-safe operations.
+
+ void WriteToFd(int Fd, const char* Buf, size_t Len)
+ {
+ while (Len > 0)
+ {
+ ssize_t Written = write(Fd, Buf, Len);
+ if (Written <= 0)
+ {
+ break;
+ }
+ Buf += Written;
+ Len -= static_cast<size_t>(Written);
+ }
+ }
+
    // Write "Msg" (optionally followed by ": <strerror(Errno)>") to the error
    // pipe and terminate the child with _exit(127). Message length is computed
    // with a manual loop to stay within the raw-syscall discipline of this
    // namespace.
    //
    // NOTE(review): strerror() is not on the POSIX async-signal-safe function
    // list (it may use locale state / static buffers), so despite the
    // namespace-level AS-safety claim this call is technically unsafe between
    // fork() and execve() — confirm or switch to a static error table.
    [[noreturn]] void WriteErrorAndExit(int ErrorPipeFd, const char* Msg, int Errno)
    {
        // Write the message prefix
        size_t MsgLen = 0;
        for (const char* P = Msg; *P; ++P)
        {
            ++MsgLen;
        }
        WriteToFd(ErrorPipeFd, Msg, MsgLen);

        // Append ": " and the errno string if non-zero
        if (Errno != 0)
        {
            WriteToFd(ErrorPipeFd, ": ", 2);
            const char* ErrStr = strerror(Errno);
            size_t      ErrLen = 0;
            for (const char* P = ErrStr; *P; ++P)
            {
                ++ErrLen;
            }
            WriteToFd(ErrorPipeFd, ErrStr, ErrLen);
        }

        // 127 conventionally signals "command could not be executed".
        _exit(127);
    }
+
+ // Build a Seatbelt profile string that denies everything by default and
+ // allows only the minimum needed for the worker to execute: process ops,
+ // system library reads, worker directory (read-only), and sandbox directory
+ // (read-write). Network access is denied implicitly by the deny-default policy.
+ std::string BuildSandboxProfile(const std::string& SandboxPath, const std::string& WorkerPath)
+ {
+ std::string Profile;
+ Profile.reserve(1024);
+
+ Profile += "(version 1)\n";
+ Profile += "(deny default)\n";
+ Profile += "(allow process*)\n";
+ Profile += "(allow sysctl-read)\n";
+ Profile += "(allow file-read-metadata)\n";
+
+ // System library paths needed for dynamic linker and runtime
+ Profile += "(allow file-read* (subpath \"/usr\"))\n";
+ Profile += "(allow file-read* (subpath \"/System\"))\n";
+ Profile += "(allow file-read* (subpath \"/Library\"))\n";
+ Profile += "(allow file-read* (subpath \"/dev\"))\n";
+ Profile += "(allow file-read* (subpath \"/private/var/db/dyld\"))\n";
+ Profile += "(allow file-read* (subpath \"/etc\"))\n";
+
+ // Worker directory: read-only
+ Profile += "(allow file-read* (subpath \"";
+ Profile += WorkerPath;
+ Profile += "\"))\n";
+
+ // Sandbox directory: read+write
+ Profile += "(allow file-read* file-write* (subpath \"";
+ Profile += SandboxPath;
+ Profile += "\"))\n";
+
+ return Profile;
+ }
+
+} // anonymous namespace
+
// Construct the macOS runner. All sandbox/worker/monitor infrastructure is
// inherited from LocalProcessRunner; this only records the Seatbelt flag and
// restores SIGCHLD handling so child exits can be collected.
MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver,
                                   const std::filesystem::path& BaseDir,
                                   DeferredDirectoryDeleter& Deleter,
                                   WorkerThreadPool& WorkerPool,
                                   bool Sandboxed,
                                   int32_t MaxConcurrentActions)
: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions)
, m_Sandboxed(Sandboxed)
{
	// Restore SIGCHLD to default behavior so waitpid() can properly collect
	// child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which
	// causes the kernel to auto-reap children, making waitpid() return
	// -1/ECHILD instead of the exit status we need.
	struct sigaction Action = {};
	sigemptyset(&Action.sa_mask);
	Action.sa_handler = SIG_DFL;
	sigaction(SIGCHLD, &Action, nullptr);

	if (m_Sandboxed)
	{
		ZEN_INFO("Seatbelt sandboxing enabled for child processes");
	}
}
+
// Spawn a worker process for the given action via fork()/execve().
//
// Flow: shared preamble (sandbox + worker manifesting) -> build envp/argv from
// the worker descriptor -> optionally build a Seatbelt profile and an
// FD_CLOEXEC error pipe -> fork. The child applies the sandbox, chdirs into
// the action sandbox, and execs the worker; any child-side failure is reported
// through the error pipe (sandboxed) or just exit code 127 (unsandboxed).
// On success the child pid is recorded in m_RunningMap for the monitor thread.
SubmitResult
MacProcessRunner::SubmitAction(Ref<RunnerAction> Action)
{
	ZEN_TRACE_CPU("MacProcessRunner::SubmitAction");
	std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action);

	if (!Prepared)
	{
		return SubmitResult{.IsAccepted = false};
	}

	// Build environment array from worker descriptor. If the descriptor lists
	// no "environment" entries the child gets an empty environment.

	CbObject WorkerDescription = Prepared->WorkerPackage.GetObject();

	std::vector<std::string> EnvStrings;
	for (auto& It : WorkerDescription["environment"sv])
	{
		EnvStrings.emplace_back(It.AsString());
	}

	// EnvStrings owns the storage; Envp only borrows pointers, so EnvStrings
	// must stay alive through execve().
	std::vector<char*> Envp;
	Envp.reserve(EnvStrings.size() + 1);
	for (auto& Str : EnvStrings)
	{
		Envp.push_back(Str.data());
	}
	Envp.push_back(nullptr);

	// Build argv: <worker_exe_path> -Build=build.action

	std::string_view      ExecPath = WorkerDescription["path"sv].AsString();
	std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath);
	std::string           ExePathStr = ExePath.string();
	std::string           BuildArg = "-Build=build.action";

	std::vector<char*> ArgV;
	ArgV.push_back(ExePathStr.data());
	ArgV.push_back(BuildArg.data());
	ArgV.push_back(nullptr);

	ZEN_DEBUG("Executing: {} {} (sandboxed={})", ExePathStr, BuildArg, m_Sandboxed);

	std::string SandboxPathStr = Prepared->SandboxPath.string();
	std::string WorkerPathStr = Prepared->WorkerPath.string();

	// Pre-fork: build sandbox profile and create error pipe. Everything that
	// allocates must happen before fork(); the child only runs AS-safe code.
	std::string SandboxProfile;
	int         ErrorPipe[2] = {-1, -1};

	if (m_Sandboxed)
	{
		SandboxProfile = BuildSandboxProfile(SandboxPathStr, WorkerPathStr);

		if (pipe(ErrorPipe) != 0)
		{
			throw zen::runtime_error("pipe() for sandbox error pipe failed: {}", strerror(errno));
		}
		// FD_CLOEXEC makes a successful execve close the write end, which the
		// parent observes as EOF on the read end ("no error").
		fcntl(ErrorPipe[0], F_SETFD, FD_CLOEXEC);
		fcntl(ErrorPipe[1], F_SETFD, FD_CLOEXEC);
	}

	pid_t ChildPid = fork();

	if (ChildPid < 0)
	{
		int SavedErrno = errno;
		if (m_Sandboxed)
		{
			close(ErrorPipe[0]);
			close(ErrorPipe[1]);
		}
		throw zen::runtime_error("fork() failed: {}", strerror(SavedErrno));
	}

	if (ChildPid == 0)
	{
		// Child process

		if (m_Sandboxed)
		{
			// Close read end of error pipe — child only writes
			close(ErrorPipe[0]);

			// Apply Seatbelt sandbox profile
			char* ErrorBuf = nullptr;
			if (sandbox_init(SandboxProfile.c_str(), 0, &ErrorBuf) != 0)
			{
				// sandbox_init failed — write error to pipe and exit
				if (ErrorBuf)
				{
					WriteErrorAndExit(ErrorPipe[1], ErrorBuf, 0);
					// WriteErrorAndExit does not return, but sandbox_free_error
					// is not needed since we _exit
				}
				WriteErrorAndExit(ErrorPipe[1], "sandbox_init failed", errno);
			}
			if (ErrorBuf)
			{
				sandbox_free_error(ErrorBuf);
			}

			if (chdir(SandboxPathStr.c_str()) != 0)
			{
				WriteErrorAndExit(ErrorPipe[1], "chdir to sandbox failed", errno);
			}

			execve(ExePathStr.c_str(), ArgV.data(), Envp.data());

			// Only reached when execve itself failed.
			WriteErrorAndExit(ErrorPipe[1], "execve failed", errno);
		}
		else
		{
			// Unsandboxed child: no error pipe, failures surface only as
			// exit code 127.
			if (chdir(SandboxPathStr.c_str()) != 0)
			{
				_exit(127);
			}

			execve(ExePathStr.c_str(), ArgV.data(), Envp.data());
			_exit(127);
		}
	}

	// Parent process

	if (m_Sandboxed)
	{
		// Close write end of error pipe — parent only reads
		close(ErrorPipe[1]);

		// Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC
		// and read returns 0. If setup failed, child wrote an error message.
		// NOTE(review): a single read() may truncate a long message to 511
		// bytes — acceptable for diagnostics, confirm if full text is needed.
		char    ErrBuf[512];
		ssize_t BytesRead = read(ErrorPipe[0], ErrBuf, sizeof(ErrBuf) - 1);
		close(ErrorPipe[0]);

		if (BytesRead > 0)
		{
			// Sandbox setup or execve failed
			ErrBuf[BytesRead] = '\0';

			// Reap the child (it called _exit(127))
			waitpid(ChildPid, nullptr, 0);

			// Clean up the sandbox in the background
			m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath));

			ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf);

			Action->SetActionState(RunnerAction::State::Failed);
			return SubmitResult{.IsAccepted = false};
		}
	}

	// Store child pid as void* (same convention as zencore/process.cpp)

	Ref<RunningAction> NewAction{new RunningAction()};
	NewAction->Action = Action;
	NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid));
	NewAction->SandboxPath = std::move(Prepared->SandboxPath);

	{
		RwLock::ExclusiveLockScope _(m_RunningLock);
		m_RunningMap[Prepared->ActionLsn] = std::move(NewAction);
	}

	Action->SetActionState(RunnerAction::State::Running);

	return SubmitResult{.IsAccepted = true};
}
+
// Reap exited children with non-blocking waitpid(), translating exit status
// into an exit code (128+signal for signal-terminated processes), then run the
// shared post-processing outside the lock (ProcessCompletedActions must not be
// called while holding m_RunningLock).
void
MacProcessRunner::SweepRunningActions()
{
	ZEN_TRACE_CPU("MacProcessRunner::SweepRunningActions");
	std::vector<Ref<RunningAction>> CompletedActions;

	m_RunningLock.WithExclusiveLock([&] {
		for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;)
		{
			Ref<RunningAction> Running = It->second;

			pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle));
			int   Status = 0;

			// WNOHANG: returns 0 immediately if the child is still running.
			pid_t Result = waitpid(Pid, &Status, WNOHANG);

			if (Result == Pid)
			{
				if (WIFEXITED(Status))
				{
					Running->ExitCode = WEXITSTATUS(Status);
				}
				else if (WIFSIGNALED(Status))
				{
					// Shell convention: 128 + terminating signal number.
					Running->ExitCode = 128 + WTERMSIG(Status);
				}
				else
				{
					// Neither exited nor signaled — treat as generic failure.
					Running->ExitCode = 1;
				}

				Running->ProcessHandle = nullptr;

				CompletedActions.push_back(std::move(Running));
				It = m_RunningMap.erase(It);
			}
			else
			{
				++It;
			}
		}
	});

	// Gather outputs / set state / schedule sandbox deletion, lock-free.
	ProcessCompletedActions(CompletedActions);
}
+
// Terminate every running child: SIGTERM first, then up to ~2s of polling per
// process, escalating to SIGKILL for stragglers. The running map is swapped
// out under the lock so the signal/wait phase runs without holding it; all
// affected actions end in the Failed state with their sandboxes queued for
// background deletion.
void
MacProcessRunner::CancelRunningActions()
{
	ZEN_TRACE_CPU("MacProcessRunner::CancelRunningActions");
	Stopwatch                                   Timer;
	std::unordered_map<int, Ref<RunningAction>> RunningMap;

	// Take ownership of the whole map; concurrent submits repopulate a fresh one.
	m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); });

	if (RunningMap.empty())
	{
		return;
	}

	ZEN_INFO("cancelling all running actions");

	// Send SIGTERM to all running processes first

	for (const auto& [Lsn, Running] : RunningMap)
	{
		pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle));

		if (kill(Pid, SIGTERM) != 0)
		{
			ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno));
		}
	}

	// Wait for all processes, regardless of whether SIGTERM succeeded, then clean up.

	for (auto& [Lsn, Running] : RunningMap)
	{
		pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle));

		// Poll for up to 2 seconds
		bool Exited = false;
		for (int i = 0; i < 20; ++i)
		{
			int   Status = 0;
			pid_t WaitResult = waitpid(Pid, &Status, WNOHANG);
			if (WaitResult == Pid)
			{
				Exited = true;
				ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn);
				break;
			}
			usleep(100000); // 100ms
		}

		if (!Exited)
		{
			// Escalate: SIGKILL cannot be ignored, then block until reaped.
			ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn);
			kill(Pid, SIGKILL);
			waitpid(Pid, nullptr, 0);
		}

		m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath));
		Running->Action->SetActionState(RunnerAction::State::Failed);
	}

	ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
+
// Request cancellation of a single action by LSN. Returns true if SIGTERM was
// delivered; the monitor thread later observes the exit and marks the action
// Failed. Returns false when the action is unknown, already reaped, or the
// signal could not be sent.
bool
MacProcessRunner::CancelAction(int ActionLsn)
{
	ZEN_TRACE_CPU("MacProcessRunner::CancelAction");

	// Hold the shared lock while sending the signal to prevent the sweep thread
	// from reaping the PID (via waitpid) between our lookup and kill(). Without
	// the lock held, the PID could be recycled by the kernel and we'd signal an
	// unrelated process.
	bool Sent = false;

	m_RunningLock.WithSharedLock([&] {
		auto It = m_RunningMap.find(ActionLsn);
		if (It == m_RunningMap.end())
		{
			return;
		}

		Ref<RunningAction> Target = It->second;
		if (!Target->ProcessHandle)
		{
			return;
		}

		pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Target->ProcessHandle));

		if (kill(Pid, SIGTERM) != 0)
		{
			ZEN_WARN("CancelAction: kill(SIGTERM) for LSN {} (pid {}) failed: {}", ActionLsn, Pid, strerror(errno));
			return;
		}

		ZEN_DEBUG("CancelAction: sent SIGTERM to LSN {} (pid {})", ActionLsn, Pid);
		Sent = true;
	});

	// The monitor thread will pick up the process exit and mark the action as Failed.
	return Sent;
}
+
// Sample one child's CPU usage via proc_pidinfo(PROC_PIDTASKINFO): publishes
// cumulative CPU seconds on every sample and, once two samples exist, the
// average CPU percentage over the interval between them. Called by the base
// class under a shared m_RunningLock.
void
MacProcessRunner::SampleProcessCpu(RunningAction& Running)
{
	const pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running.ProcessHandle));

	struct proc_taskinfo Info;
	if (proc_pidinfo(Pid, PROC_PIDTASKINFO, 0, &Info, sizeof(Info)) <= 0)
	{
		// Process may have just exited; skip this sample.
		return;
	}

	// pti_total_user and pti_total_system are in nanoseconds
	const uint64_t CurrentOsTicks = Info.pti_total_user + Info.pti_total_system;
	const uint64_t NowTicks = GetHifreqTimerValue();

	// Cumulative CPU seconds (absolute, available from first sample): ns → seconds
	Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 1'000'000'000.0), std::memory_order_relaxed);

	// Percentage needs a previous sample to form a delta.
	if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0)
	{
		const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks);
		if (ElapsedMs > 0)
		{
			const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks;
			// ns → ms: divide by 1,000,000; then as percent of elapsed ms
			const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 1'000'000.0 / ElapsedMs * 100.0);
			Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed);
		}
	}

	Running.LastCpuSampleTicks = NowTicks;
	Running.LastCpuOsTicks = CurrentOsTicks;
}
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/macrunner.h b/src/zencompute/runners/macrunner.h
new file mode 100644
index 000000000..d653b923a
--- /dev/null
+++ b/src/zencompute/runners/macrunner.h
@@ -0,0 +1,43 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "localrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_MAC
+
+namespace zen::compute {
+
+/** Native macOS process runner for executing Mac worker executables directly.
+
+ Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting,
+ input/output handling, and monitor thread infrastructure. Overrides only the
+ platform-specific methods: process spawning, sweep, and cancellation.
+
+ When Sandboxed is true, child processes are isolated using macOS Seatbelt
+ (sandbox_init): no network access and no filesystem access outside the
+ explicitly allowed sandbox and worker directories. This requires no elevation.
+ */
class MacProcessRunner : public LocalProcessRunner
{
public:
    MacProcessRunner(ChunkResolver& Resolver,
                     const std::filesystem::path& BaseDir,
                     DeferredDirectoryDeleter& Deleter,
                     WorkerThreadPool& WorkerPool,
                     bool Sandboxed = false,
                     int32_t MaxConcurrentActions = 0);

    // fork()/execve()-based process launch (optionally Seatbelt-sandboxed).
    [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
    // Non-blocking waitpid() reaping of exited children.
    void SweepRunningActions() override;
    // SIGTERM all children, escalating to SIGKILL after a grace period.
    void CancelRunningActions() override;
    // SIGTERM a single action by LSN; sweep handles the resulting exit.
    bool CancelAction(int ActionLsn) override;
    // Per-process CPU sampling via proc_pidinfo().
    void SampleProcessCpu(RunningAction& Running) override;

private:
    // When true, children run under a deny-by-default Seatbelt profile.
    bool m_Sandboxed = false;
};
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp
new file mode 100644
index 000000000..672636d06
--- /dev/null
+++ b/src/zencompute/runners/remotehttprunner.cpp
@@ -0,0 +1,618 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "remotehttprunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/compress.h>
+# include <zencore/except.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/iobuffer.h>
+# include <zencore/iohash.h>
+# include <zencore/scopeguard.h>
+# include <zencore/system.h>
+# include <zencore/trace.h>
+# include <zenhttp/httpcommon.h>
+# include <zenstore/cidstore.h>
+
+# include <span>
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+
+// Constructs the runner and immediately starts the monitor thread, which
+// periodically polls the remote node for completed actions (see
+// MonitorThreadFunction / SweepRunningActions).
+RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver,
+ const std::filesystem::path& BaseDir,
+ std::string_view HostName,
+ WorkerThreadPool& InWorkerPool)
+: FunctionRunner(BaseDir)
+, m_Log(logging::Get("http_exec"))
+, m_ChunkResolver{InChunkResolver}
+, m_WorkerPool{InWorkerPool}
+, m_HostName{HostName}
+, m_BaseUrl{fmt::format("{}/compute", HostName)} // all endpoints are rooted under /compute
+, m_Http(m_BaseUrl)
+, m_InstanceId(Oid::NewOid()) // stable per-instance id, used in queue idempotency keys
+{
+ m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this};
+}
+
+// Stops and joins the monitor thread. Shutdown() is safe to call more than
+// once, so an earlier explicit call by the owner is fine.
+RemoteHttpRunner::~RemoteHttpRunner()
+{
+ Shutdown();
+}
+
+void
+RemoteHttpRunner::Shutdown()
+{
+ // TODO: should cleanly drain/cancel pending work
+
+ // Ask the monitor thread to exit and wake it so it notices promptly,
+ // then join. joinable() guards against a second Shutdown() call.
+ m_MonitorThreadEnabled = false;
+ m_MonitorThreadEvent.Set();
+ if (m_MonitorThread.joinable())
+ {
+ m_MonitorThread.join();
+ }
+}
+
+// Idempotently registers a worker package with the remote node.
+//
+// Protocol: GET /workers/{id} to probe; on 404, POST the descriptor object.
+// If the remote then replies 404 with a "need" list of attachment hashes,
+// upload a package trimmed down to exactly those attachments. Errors are
+// currently logged but not propagated (see TODOs).
+void
+RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage)
+{
+ ZEN_TRACE_CPU("RemoteHttpRunner::RegisterWorker");
+ const IoHash WorkerId = WorkerPackage.GetObjectHash();
+ CbPackage WorkerDesc = WorkerPackage;
+
+ std::string WorkerUrl = fmt::format("/workers/{}", WorkerId);
+
+ HttpClient::Response WorkerResponse = m_Http.Get(WorkerUrl);
+
+ if (WorkerResponse.StatusCode == HttpResponseCode::NotFound)
+ {
+ HttpClient::Response DescResponse = m_Http.Post(WorkerUrl, WorkerDesc.GetObject());
+
+ if (DescResponse.StatusCode == HttpResponseCode::NotFound)
+ {
+ CbPackage Pkg = WorkerDesc;
+
+ // Build response package by sending only the attachments
+ // the other end needs. We start with the full package and
+ // remove the attachments which are not needed.
+
+ {
+ std::unordered_set<IoHash> Needed;
+
+ CbObject Response = DescResponse.AsObject();
+
+ // The 404 body enumerates the attachment hashes the remote lacks.
+ for (auto& Item : Response["need"sv])
+ {
+ const IoHash NeedHash = Item.AsHash();
+
+ Needed.insert(NeedHash);
+ }
+
+ std::unordered_set<IoHash> ToRemove;
+
+ // Collect first, then remove, to avoid mutating the attachment
+ // list while iterating over it.
+ for (const CbAttachment& Attachment : Pkg.GetAttachments())
+ {
+ const IoHash& Hash = Attachment.GetHash();
+
+ if (Needed.find(Hash) == Needed.end())
+ {
+ ToRemove.insert(Hash);
+ }
+ }
+
+ for (const IoHash& Hash : ToRemove)
+ {
+ int RemovedCount = Pkg.RemoveAttachment(Hash);
+
+ ZEN_ASSERT(RemovedCount == 1);
+ }
+ }
+
+ // Post resulting package
+
+ HttpClient::Response PayloadResponse = m_Http.Post(WorkerUrl, Pkg);
+
+ if (!IsHttpSuccessCode(PayloadResponse.StatusCode))
+ {
+ ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl);
+
+ // TODO: propagate error
+ }
+ }
+ else if (!IsHttpSuccessCode(DescResponse.StatusCode))
+ {
+ ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl);
+
+ // TODO: propagate error
+ }
+ else
+ {
+ // Success with no "need" list: the remote accepted the descriptor as-is.
+ ZEN_ASSERT(DescResponse.StatusCode == HttpResponseCode::NoContent);
+ }
+ }
+ else if (WorkerResponse.StatusCode == HttpResponseCode::OK)
+ {
+ // Already known from a previous run
+ }
+ else if (!IsHttpSuccessCode(WorkerResponse.StatusCode))
+ {
+ ZEN_ERROR("ERROR: unable to look up worker {} at {}{} (error: {} {})",
+ WorkerId,
+ m_Http.GetBaseUri(),
+ WorkerUrl,
+ (int)WorkerResponse.StatusCode,
+ ToString(WorkerResponse.StatusCode));
+
+ // TODO: propagate error
+ }
+}
+
+size_t
+RemoteHttpRunner::QueryCapacity()
+{
+ // Estimate how much more work we're ready to accept
+
+ RwLock::SharedLockScope _{m_RunningLock};
+
+ size_t RunningCount = m_RemoteRunningMap.size();
+
+ // Guard against underflow: the cap can be lowered at runtime (see the
+ // metrics handling in SweepRunningActions) below the current count.
+ if (RunningCount >= size_t(m_MaxRunningActions))
+ {
+ return 0;
+ }
+
+ return m_MaxRunningActions - RunningCount;
+}
+
+// Submits a batch of actions, returning one SubmitResult per input in order.
+// Batches of size <= 1 are handled inline; larger batches fan out the HTTP
+// submissions across the shared worker pool and join on the futures.
+std::vector<SubmitResult>
+RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
+{
+ ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions");
+
+ if (Actions.size() <= 1)
+ {
+ std::vector<SubmitResult> Results;
+
+ for (const Ref<RunnerAction>& Action : Actions)
+ {
+ Results.push_back(SubmitAction(Action));
+ }
+
+ return Results;
+ }
+
+ // For larger batches, submit HTTP requests in parallel via the shared worker pool
+
+ std::vector<std::future<SubmitResult>> Futures;
+ Futures.reserve(Actions.size());
+
+ for (const Ref<RunnerAction>& Action : Actions)
+ {
+ // Capture the Ref by value so the action stays alive until the task runs.
+ std::packaged_task<SubmitResult()> Task([this, Action]() { return SubmitAction(Action); });
+
+ Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog));
+ }
+
+ std::vector<SubmitResult> Results;
+ Results.reserve(Futures.size());
+
+ // Future order mirrors input order, so results line up with Actions.
+ for (auto& Future : Futures)
+ {
+ Results.push_back(Future.get());
+ }
+
+ return Results;
+}
+
+// Submits a single action to the remote node.
+//
+// Flow: check local capacity; pick the submit URL (per-queue when the action
+// belongs to a queue, see EnsureRemoteQueue); POST the action object. A 424
+// (FailedDependency) means the remote cannot resolve the worker, so the worker
+// is re-registered and the POST retried once. A 404 means attachments are
+// missing: the "need" list is resolved via the chunk resolver and uploaded as
+// a package. On success the remote's returned "lsn" keys the running map and
+// the action transitions to Running.
+SubmitResult
+RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action)
+{
+ ZEN_TRACE_CPU("RemoteHttpRunner::SubmitAction");
+
+ // Verify whether we can accept more work
+
+ {
+ RwLock::SharedLockScope _{m_RunningLock};
+ if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions))
+ {
+ return SubmitResult{.IsAccepted = false};
+ }
+ }
+
+ using namespace std::literals;
+
+ // Each enqueued action is assigned an integer index (logical sequence number),
+ // which we use as a key for tracking data structures and as an opaque id which
+ // may be used by clients to reference the scheduled action
+
+ Action->ExecutionLocation = m_HostName;
+
+ const int32_t ActionLsn = Action->ActionLsn;
+ const CbObject& ActionObj = Action->ActionObj;
+ const IoHash ActionId = ActionObj.GetHash();
+
+ MaybeDumpAction(ActionLsn, ActionObj);
+
+ // Determine the submission URL. If the action belongs to a queue, ensure a
+ // corresponding remote queue exists on the target node and submit via it.
+
+ std::string SubmitUrl = "/jobs";
+ if (const int QueueId = Action->QueueId; QueueId != 0)
+ {
+ CbObject QueueMeta = Action->GetOwnerSession()->GetQueueMetadata(QueueId);
+ CbObject QueueConfig = Action->GetOwnerSession()->GetQueueConfig(QueueId);
+ if (Oid Token = EnsureRemoteQueue(QueueId, QueueMeta, QueueConfig); Token != Oid::Zero)
+ {
+ SubmitUrl = fmt::format("/queues/{}/jobs", Token);
+ }
+ }
+
+ // Enqueue job. If the remote returns FailedDependency (424), it means it
+ // cannot resolve the worker/function — re-register the worker and retry once.
+
+ CbObject Result;
+ HttpClient::Response WorkResponse;
+ HttpResponseCode WorkResponseCode{};
+
+ for (int Attempt = 0; Attempt < 2; ++Attempt)
+ {
+ WorkResponse = m_Http.Post(SubmitUrl, ActionObj);
+ WorkResponseCode = WorkResponse.StatusCode;
+
+ if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0)
+ {
+ ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying",
+ m_Http.GetBaseUri(),
+ ActionId);
+
+ RegisterWorker(Action->Worker.Descriptor);
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ if (WorkResponseCode == HttpResponseCode::OK)
+ {
+ Result = WorkResponse.AsObject();
+ }
+ else if (WorkResponseCode == HttpResponseCode::NotFound)
+ {
+ // Not all attachments are present
+
+ // Build response package including all required attachments
+
+ CbPackage Pkg;
+ Pkg.SetObject(ActionObj);
+
+ CbObject Response = WorkResponse.AsObject();
+
+ for (auto& Item : Response["need"sv])
+ {
+ const IoHash NeedHash = Item.AsHash();
+
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ {
+ uint64_t DataRawSize = 0;
+ IoHash DataRawHash;
+ CompressedBuffer Compressed =
+ CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
+
+ // The raw hash decoded from the compressed chunk must match the
+ // content id the remote asked for.
+ ZEN_ASSERT(DataRawHash == NeedHash);
+
+ Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+ }
+ else
+ {
+ // No such attachment
+
+ return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)};
+ }
+ }
+
+ // Post resulting package
+
+ HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg);
+
+ if (!PayloadResponse)
+ {
+ ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl);
+
+ // TODO: include more information about the failure in the response
+
+ return {.IsAccepted = false, .Reason = "HTTP request failed"};
+ }
+ else if (PayloadResponse.StatusCode == HttpResponseCode::OK)
+ {
+ Result = PayloadResponse.AsObject();
+ }
+ else
+ {
+ // Unexpected response
+
+ const int ResponseStatusCode = (int)PayloadResponse.StatusCode;
+
+ // Pass the enum (not the int) to ToString so the log and reason carry
+ // the reason phrase, matching the other call sites in this file.
+ ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})",
+ ActionId,
+ m_Http.GetBaseUri(),
+ SubmitUrl,
+ ResponseStatusCode,
+ ToString(PayloadResponse.StatusCode));
+
+ return {.IsAccepted = false,
+ .Reason = fmt::format("unexpected response code {} {} from {}{}",
+ ResponseStatusCode,
+ ToString(PayloadResponse.StatusCode),
+ m_Http.GetBaseUri(),
+ SubmitUrl)};
+ }
+ }
+
+ if (Result)
+ {
+ if (const int32_t LsnField = Result["lsn"].AsInt32(0))
+ {
+ HttpRunningAction NewAction;
+ NewAction.Action = Action;
+ NewAction.RemoteActionLsn = LsnField;
+
+ {
+ RwLock::ExclusiveLockScope _(m_RunningLock);
+
+ // Keyed on the REMOTE lsn (see m_RemoteRunningMap declaration).
+ m_RemoteRunningMap[LsnField] = std::move(NewAction);
+ }
+
+ ZEN_DEBUG("scheduled action {} with remote LSN {} (local LSN {})", ActionId, LsnField, ActionLsn);
+
+ Action->SetActionState(RunnerAction::State::Running);
+
+ return SubmitResult{.IsAccepted = true};
+ }
+ }
+
+ // Any other failure: default (not accepted, no reason recorded).
+ return {};
+}
+
+// Returns the remote queue token for a local queue id, creating the remote
+// queue on first use. Returns Oid::Zero on failure. Safe to call concurrently:
+// the idempotency key makes redundant server-side creates converge on one
+// token, and try_emplace keeps the first locally cached token.
+Oid
+RemoteHttpRunner::EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config)
+{
+ // Fast path: token already cached.
+ {
+ RwLock::SharedLockScope _(m_QueueTokenLock);
+ if (auto It = m_RemoteQueueTokens.find(QueueId); It != m_RemoteQueueTokens.end())
+ {
+ return It->second;
+ }
+ }
+
+ // Build a stable idempotency key that uniquely identifies this (runner instance, local queue)
+ // pair. The server uses this to return the same remote queue token for concurrent or redundant
+ // requests, preventing orphaned remote queues when multiple threads race through here.
+ // Also send hostname so the server can associate the queue with its origin for diagnostics.
+ CbObjectWriter Body;
+ Body << "idempotency_key"sv << fmt::format("{}/{}", m_InstanceId, QueueId);
+ Body << "hostname"sv << GetMachineName();
+ if (Metadata)
+ {
+ Body << "metadata"sv << Metadata;
+ }
+ if (Config)
+ {
+ Body << "config"sv << Config;
+ }
+
+ HttpClient::Response Resp = m_Http.Post("/queues/remote", Body.Save());
+ if (!Resp)
+ {
+ ZEN_WARN("failed to create remote queue for local queue {} on {}", QueueId, m_HostName);
+ return Oid::Zero;
+ }
+
+ Oid Token = Oid::TryFromHexString(Resp.AsObject()["queue_token"sv].AsString());
+ if (Token == Oid::Zero)
+ {
+ return Oid::Zero;
+ }
+
+ ZEN_DEBUG("created remote queue '{}' for local queue {} on {}", Token, QueueId, m_HostName);
+
+ // If another thread cached a token first, keep theirs (try_emplace is a
+ // no-op on an existing key) — the idempotency key guarantees it is the
+ // same remote queue anyway.
+ RwLock::ExclusiveLockScope _(m_QueueTokenLock);
+ auto [It, Inserted] = m_RemoteQueueTokens.try_emplace(QueueId, Token);
+ return It->second;
+}
+
+// Cancels the remote queue mapped to a local queue id, if one was ever
+// created. The local token cache entry is intentionally left in place.
+void
+RemoteHttpRunner::CancelRemoteQueue(int QueueId)
+{
+ // Look up the token under the shared lock, then issue the HTTP call
+ // outside the lock.
+ Oid Token;
+ {
+ RwLock::SharedLockScope _(m_QueueTokenLock);
+ if (auto It = m_RemoteQueueTokens.find(QueueId); It != m_RemoteQueueTokens.end())
+ {
+ Token = It->second;
+ }
+ }
+
+ if (Token == Oid::Zero)
+ {
+ return;
+ }
+
+ HttpClient::Response Resp = m_Http.Delete(fmt::format("/queues/{}", Token));
+
+ if (Resp.StatusCode == HttpResponseCode::NoContent)
+ {
+ ZEN_DEBUG("cancelled remote queue '{}' (local queue {}) on {}", Token, QueueId, m_HostName);
+ }
+ else
+ {
+ ZEN_WARN("failed to cancel remote queue '{}' on {}: {}", Token, m_HostName, int(Resp.StatusCode));
+ }
+}
+
+// Probes the remote node's /ready endpoint; true when the request succeeds.
+bool
+RemoteHttpRunner::IsHealthy()
+{
+ if (HttpClient::Response Ready = m_Http.Get("/ready"))
+ {
+ return true;
+ }
+ else
+ {
+ // TODO: use response to propagate context
+ return false;
+ }
+}
+
+// Number of actions currently tracked as running on the remote node.
+size_t
+RemoteHttpRunner::GetSubmittedActionCount()
+{
+ RwLock::SharedLockScope _(m_RunningLock);
+ return m_RemoteRunningMap.size();
+}
+
+// Monitor thread body: periodically sweeps the remote for completed actions,
+// adapting the polling interval to load (poll faster when many actions are
+// in flight or when the last sweep retired something). Exits when
+// m_MonitorThreadEnabled is cleared and the event is signalled (Shutdown).
+void
+RemoteHttpRunner::MonitorThreadFunction()
+{
+ SetCurrentThreadName("RemoteHttpRunner_Monitor");
+
+ do
+ {
+ const int NormalWaitingTime = 200;
+ int WaitTimeMs = NormalWaitingTime;
+ // Wait() returns truthy when the event was signalled (shutdown request),
+ // falsy on timeout (time to poll again).
+ auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); };
+ auto SweepOnce = [&] {
+ const size_t RetiredCount = SweepRunningActions();
+
+ // Adapt the next wait based on how busy we are.
+ m_RunningLock.WithSharedLock([&] {
+ if (m_RemoteRunningMap.size() > 16)
+ {
+ WaitTimeMs = NormalWaitingTime / 4;
+ }
+ else
+ {
+ if (RetiredCount)
+ {
+ WaitTimeMs = NormalWaitingTime / 2;
+ }
+ else
+ {
+ WaitTimeMs = NormalWaitingTime;
+ }
+ }
+ });
+ };
+
+ while (!WaitOnce())
+ {
+ SweepOnce();
+ }
+
+ // Signal received - this may mean we should quit
+
+ // Final sweep so completions observed during shutdown are not dropped.
+ SweepOnce();
+ } while (m_MonitorThreadEnabled);
+}
+
+// Polls the remote /jobs/completed endpoint, retires finished actions and
+// notifies their owners. Returns the number of actions retired this sweep.
+// Also reads the remote's advertised logical-processor count and may lower
+// m_MaxRunningActions accordingly.
+size_t
+RemoteHttpRunner::SweepRunningActions()
+{
+ ZEN_TRACE_CPU("RemoteHttpRunner::SweepRunningActions");
+ std::vector<HttpRunningAction> CompletedActions;
+
+ // Poll remote for list of completed actions
+
+ HttpClient::Response ResponseCompleted = m_Http.Get("/jobs/completed"sv);
+
+ if (CbObject Completed = ResponseCompleted.AsObject())
+ {
+ for (auto& FieldIt : Completed["completed"sv])
+ {
+ CbObjectView EntryObj = FieldIt.AsObjectView();
+ const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32();
+ std::string_view StateName = EntryObj["state"sv].AsString();
+
+ RunnerAction::State RemoteState = RunnerAction::FromString(StateName);
+
+ // Always fetch to drain the result from the remote's results map,
+ // but only keep the result package for successfully completed actions.
+ HttpClient::Response ResponseJob = m_Http.Get(fmt::format("/jobs/{}"sv, CompleteLsn));
+
+ m_RunningLock.WithExclusiveLock([&] {
+ if (auto CompleteIt = m_RemoteRunningMap.find(CompleteLsn); CompleteIt != m_RemoteRunningMap.end())
+ {
+ HttpRunningAction CompletedAction = std::move(CompleteIt->second);
+ CompletedAction.RemoteState = RemoteState;
+
+ if (RemoteState == RunnerAction::State::Completed && ResponseJob)
+ {
+ CompletedAction.ActionResults = ResponseJob.AsPackage();
+ }
+
+ CompletedActions.push_back(std::move(CompletedAction));
+ m_RemoteRunningMap.erase(CompleteIt);
+ }
+ else
+ {
+ // we received a completion notice for an action we don't know about,
+ // this can happen if the runner is used by multiple upstream schedulers,
+ // or if this compute node was recently restarted and lost track of
+ // previously scheduled actions
+ }
+ });
+ }
+
+ if (CbObjectView Metrics = Completed["metrics"sv].AsObjectView())
+ {
+ // if (const size_t CpuCount = Metrics["core_count"].AsInt32(0))
+ if (const int32_t CpuCount = Metrics["lp_count"].AsInt32(0))
+ {
+ // Never cap below 4; only ever lower the cap, never raise it here.
+ const int32_t NewCap = zen::Max(4, CpuCount);
+
+ if (m_MaxRunningActions > NewCap)
+ {
+ ZEN_DEBUG("capping {} to {} actions (was {})", m_BaseUrl, NewCap, m_MaxRunningActions);
+
+ m_MaxRunningActions = NewCap;
+ }
+ }
+ }
+ }
+
+ // Notify outer. Note that this has to be done without holding any local locks
+ // otherwise we may end up with deadlocks.
+
+ for (HttpRunningAction& HttpAction : CompletedActions)
+ {
+ const int ActionLsn = HttpAction.Action->ActionLsn;
+
+ ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ HttpAction.RemoteActionLsn,
+ RunnerAction::ToString(HttpAction.RemoteState));
+
+ // Attach results before the state transition so observers woken by
+ // SetActionState see them.
+ if (HttpAction.RemoteState == RunnerAction::State::Completed)
+ {
+ HttpAction.Action->SetResult(std::move(HttpAction.ActionResults));
+ }
+
+ HttpAction.Action->SetActionState(HttpAction.RemoteState);
+ }
+
+ return CompletedActions.size();
+}
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h
new file mode 100644
index 000000000..9119992a9
--- /dev/null
+++ b/src/zencompute/runners/remotehttprunner.h
@@ -0,0 +1,100 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencompute/computeservice.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include "functionrunner.h"
+
+# include <zencore/compactbinarypackage.h>
+# include <zencore/logging.h>
+# include <zencore/uid.h>
+# include <zencore/workthreadpool.h>
+# include <zencore/zencore.h>
+# include <zenhttp/httpclient.h>
+
+# include <atomic>
+# include <filesystem>
+# include <thread>
+# include <unordered_map>
+
+namespace zen {
+class CidStore;
+}
+
+namespace zen::compute {
+
+/** HTTP-based runner
+
+ This implements a DDC remote compute execution strategy via REST API
+
+ */
+
+class RemoteHttpRunner : public FunctionRunner
+{
+ // Non-movable: the monitor thread captures `this`.
+ RemoteHttpRunner(RemoteHttpRunner&&) = delete;
+ RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete;
+
+public:
+ RemoteHttpRunner(ChunkResolver& InChunkResolver,
+ const std::filesystem::path& BaseDir,
+ std::string_view HostName,
+ WorkerThreadPool& InWorkerPool);
+ ~RemoteHttpRunner();
+
+ virtual void Shutdown() override;
+ virtual void RegisterWorker(const CbPackage& WorkerPackage) override;
+ [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
+ [[nodiscard]] virtual bool IsHealthy() override;
+ [[nodiscard]] virtual size_t GetSubmittedActionCount() override;
+ [[nodiscard]] virtual size_t QueryCapacity() override;
+ [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override;
+ virtual void CancelRemoteQueue(int QueueId) override;
+
+ std::string_view GetHostName() const { return m_HostName; }
+
+protected:
+ LoggerRef Log() { return m_Log; }
+
+private:
+ LoggerRef m_Log;
+ ChunkResolver& m_ChunkResolver; // Resolves attachment content ids to chunk data
+ WorkerThreadPool& m_WorkerPool; // Shared pool used to parallelize batch submissions
+ std::string m_HostName;
+ std::string m_BaseUrl; // "<host>/compute" — root of all endpoints below
+ HttpClient m_Http;
+
+ int32_t m_MaxRunningActions = 256; // arbitrary limit for testing
+
+ // Tracking record for one action submitted to the remote node.
+ struct HttpRunningAction
+ {
+ Ref<RunnerAction> Action;
+ int RemoteActionLsn = 0; // Remote LSN
+ RunnerAction::State RemoteState = RunnerAction::State::Failed;
+ CbPackage ActionResults; // Result package, only set for Completed actions
+ };
+
+ RwLock m_RunningLock;
+ std::unordered_map<int, HttpRunningAction> m_RemoteRunningMap; // Note that this is keyed on the *REMOTE* lsn
+
+ // Background thread polling the remote for completions (see SweepRunningActions).
+ std::thread m_MonitorThread;
+ std::atomic<bool> m_MonitorThreadEnabled{true};
+ Event m_MonitorThreadEvent;
+ void MonitorThreadFunction();
+ size_t SweepRunningActions();
+
+ RwLock m_QueueTokenLock;
+ std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId → remote queue token
+
+ // Stable identity for this runner instance, used as part of the idempotency key when
+ // creating remote queues. Generated once at construction and never changes.
+ Oid m_InstanceId;
+
+ Oid EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config);
+};
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp
new file mode 100644
index 000000000..e9a1ae8b6
--- /dev/null
+++ b/src/zencompute/runners/windowsrunner.cpp
@@ -0,0 +1,460 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "windowsrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_WINDOWS
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/except.h>
+# include <zencore/except_fmt.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/scopeguard.h>
+# include <zencore/trace.h>
+# include <zencore/system.h>
+# include <zencore/timer.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <userenv.h>
+# include <aclapi.h>
+# include <sddl.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+// Constructs the runner; when Sandboxed, creates a process-unique AppContainer
+// profile (no capabilities granted) whose SID is attached to every spawned
+// child. Throws on profile creation failure.
+WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver,
+ const std::filesystem::path& BaseDir,
+ DeferredDirectoryDeleter& Deleter,
+ WorkerThreadPool& WorkerPool,
+ bool Sandboxed,
+ int32_t MaxConcurrentActions)
+: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions)
+, m_Sandboxed(Sandboxed)
+{
+ if (!m_Sandboxed)
+ {
+ return;
+ }
+
+ // Build a unique profile name per process to avoid collisions
+ m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId());
+
+ // Clean up any stale profile from a previous crash
+ DeleteAppContainerProfile(m_AppContainerName.c_str());
+
+ PSID Sid = nullptr;
+
+ HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(),
+ m_AppContainerName.c_str(), // display name
+ m_AppContainerName.c_str(), // description
+ nullptr, // no capabilities
+ 0, // capability count
+ &Sid);
+
+ if (FAILED(Hr))
+ {
+ throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr));
+ }
+
+ // Owned by this object; released with FreeSid in the destructor.
+ m_AppContainerSid = Sid;
+
+ ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName));
+}
+
+// Releases the AppContainer SID and removes the per-process profile
+// (both are no-ops when sandboxing was not enabled).
+WindowsProcessRunner::~WindowsProcessRunner()
+{
+ if (m_AppContainerSid)
+ {
+ FreeSid(m_AppContainerSid);
+ m_AppContainerSid = nullptr;
+ }
+
+ if (!m_AppContainerName.empty())
+ {
+ DeleteAppContainerProfile(m_AppContainerName.c_str());
+ }
+}
+
+// Adds an inheritable ACE granting the AppContainer SID the given access mask
+// on Path (read-modify-write of the existing DACL, so prior entries are kept).
+// Throws on any security-API failure.
+void
+WindowsProcessRunner::GrantAppContainerAccess(const std::filesystem::path& Path, DWORD AccessMask)
+{
+ PACL ExistingDacl = nullptr;
+ PSECURITY_DESCRIPTOR SecurityDescriptor = nullptr;
+
+ // Read the current DACL; SecurityDescriptor owns ExistingDacl's storage.
+ DWORD Result = GetNamedSecurityInfoW(Path.c_str(),
+ SE_FILE_OBJECT,
+ DACL_SECURITY_INFORMATION,
+ nullptr,
+ nullptr,
+ &ExistingDacl,
+ nullptr,
+ &SecurityDescriptor);
+
+ if (Result != ERROR_SUCCESS)
+ {
+ throw zen::runtime_error("GetNamedSecurityInfoW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result));
+ }
+
+ auto $0 = MakeGuard([&] { LocalFree(SecurityDescriptor); });
+
+ // ACE inherits to files and subdirectories so the whole tree is covered.
+ EXPLICIT_ACCESSW Access{};
+ Access.grfAccessPermissions = AccessMask;
+ Access.grfAccessMode = SET_ACCESS;
+ Access.grfInheritance = OBJECT_INHERIT_ACE | CONTAINER_INHERIT_ACE;
+ Access.Trustee.TrusteeForm = TRUSTEE_IS_SID;
+ Access.Trustee.TrusteeType = TRUSTEE_IS_WELL_KNOWN_GROUP;
+ Access.Trustee.ptstrName = static_cast<LPWSTR>(m_AppContainerSid);
+
+ PACL NewDacl = nullptr;
+
+ Result = SetEntriesInAclW(1, &Access, ExistingDacl, &NewDacl);
+ if (Result != ERROR_SUCCESS)
+ {
+ throw zen::runtime_error("SetEntriesInAclW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result));
+ }
+
+ auto $1 = MakeGuard([&] { LocalFree(NewDacl); });
+
+ Result = SetNamedSecurityInfoW(const_cast<LPWSTR>(Path.c_str()),
+ SE_FILE_OBJECT,
+ DACL_SECURITY_INFORMATION,
+ nullptr,
+ nullptr,
+ NewDacl,
+ nullptr);
+
+ if (Result != ERROR_SUCCESS)
+ {
+ throw zen::runtime_error("SetNamedSecurityInfoW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result));
+ }
+}
+
+// Spawns the worker process for an action via CreateProcessW, optionally
+// inside the AppContainer, and registers it in the running map. Returns
+// not-accepted when the base class declines to prepare the submission;
+// throws on process-creation failure.
+SubmitResult
+WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action)
+{
+ ZEN_TRACE_CPU("WindowsProcessRunner::SubmitAction");
+ std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action);
+
+ if (!Prepared)
+ {
+ return SubmitResult{.IsAccepted = false};
+ }
+
+ // Set up environment variables
+
+ CbObject WorkerDescription = Prepared->WorkerPackage.GetObject();
+
+ // NOTE(review): narrow-char environment block passed to CreateProcessW
+ // without CREATE_UNICODE_ENVIRONMENT, i.e. treated as ANSI — presumably
+ // intentional since the entries originate as UTF-8-ish strings; confirm
+ // non-ASCII environment values are not expected here.
+ StringBuilder<1024> EnvironmentBlock;
+
+ for (auto& It : WorkerDescription["environment"sv])
+ {
+ EnvironmentBlock.Append(It.AsString());
+ EnvironmentBlock.Append('\0');
+ }
+ // Environment block is terminated by an empty string (double NUL).
+ EnvironmentBlock.Append('\0');
+ EnvironmentBlock.Append('\0');
+
+ // Execute process - this spawns the child process immediately without waiting
+ // for completion
+
+ std::string_view ExecPath = WorkerDescription["path"sv].AsString();
+ std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred();
+
+ // Quote the exe path since it may contain spaces.
+ ExtendableWideStringBuilder<512> CommandLine;
+ CommandLine.Append(L'"');
+ CommandLine.Append(ExePath.c_str());
+ CommandLine.Append(L'"');
+ CommandLine.Append(L" -Build=build.action");
+
+ LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr;
+ LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr;
+ BOOL bInheritHandles = FALSE;
+ DWORD dwCreationFlags = 0;
+
+ ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed);
+
+ CommandLine.EnsureNulTerminated();
+
+ PROCESS_INFORMATION ProcessInformation{};
+
+ if (m_Sandboxed)
+ {
+ // Grant AppContainer access to sandbox and worker directories
+ GrantAppContainerAccess(Prepared->SandboxPath, FILE_ALL_ACCESS);
+ GrantAppContainerAccess(Prepared->WorkerPath, FILE_GENERIC_READ | FILE_GENERIC_EXECUTE);
+
+ // Set up extended startup info with AppContainer security capabilities
+ SECURITY_CAPABILITIES SecurityCapabilities{};
+ SecurityCapabilities.AppContainerSid = m_AppContainerSid;
+ SecurityCapabilities.Capabilities = nullptr;
+ SecurityCapabilities.CapabilityCount = 0;
+
+ // First call intentionally fails with the required buffer size.
+ SIZE_T AttrListSize = 0;
+ InitializeProcThreadAttributeList(nullptr, 1, 0, &AttrListSize);
+
+ // NOTE(review): malloc result is not checked for nullptr before use.
+ auto AttrList = static_cast<PPROC_THREAD_ATTRIBUTE_LIST>(malloc(AttrListSize));
+ auto $0 = MakeGuard([&] { free(AttrList); });
+
+ if (!InitializeProcThreadAttributeList(AttrList, 1, 0, &AttrListSize))
+ {
+ zen::ThrowLastError("InitializeProcThreadAttributeList failed");
+ }
+
+ auto $1 = MakeGuard([&] { DeleteProcThreadAttributeList(AttrList); });
+
+ if (!UpdateProcThreadAttribute(AttrList,
+ 0,
+ PROC_THREAD_ATTRIBUTE_SECURITY_CAPABILITIES,
+ &SecurityCapabilities,
+ sizeof(SecurityCapabilities),
+ nullptr,
+ nullptr))
+ {
+ zen::ThrowLastError("UpdateProcThreadAttribute (SECURITY_CAPABILITIES) failed");
+ }
+
+ STARTUPINFOEXW StartupInfoEx{};
+ StartupInfoEx.StartupInfo.cb = sizeof(STARTUPINFOEXW);
+ StartupInfoEx.lpAttributeList = AttrList;
+
+ dwCreationFlags |= EXTENDED_STARTUPINFO_PRESENT;
+
+ BOOL Success = CreateProcessW(nullptr,
+ CommandLine.Data(),
+ lpProcessAttributes,
+ lpThreadAttributes,
+ bInheritHandles,
+ dwCreationFlags,
+ (LPVOID)EnvironmentBlock.Data(),
+ Prepared->SandboxPath.c_str(),
+ &StartupInfoEx.StartupInfo,
+ /* out */ &ProcessInformation);
+
+ if (!Success)
+ {
+ zen::ThrowLastError("Unable to launch sandboxed process");
+ }
+ }
+ else
+ {
+ STARTUPINFO StartupInfo{};
+ StartupInfo.cb = sizeof StartupInfo;
+
+ BOOL Success = CreateProcessW(nullptr,
+ CommandLine.Data(),
+ lpProcessAttributes,
+ lpThreadAttributes,
+ bInheritHandles,
+ dwCreationFlags,
+ (LPVOID)EnvironmentBlock.Data(),
+ Prepared->SandboxPath.c_str(),
+ &StartupInfo,
+ /* out */ &ProcessInformation);
+
+ if (!Success)
+ {
+ zen::ThrowLastError("Unable to launch process");
+ }
+ }
+
+ // Only the process handle is tracked; the thread handle is not needed.
+ CloseHandle(ProcessInformation.hThread);
+
+ Ref<RunningAction> NewAction{new RunningAction()};
+ NewAction->Action = Action;
+ NewAction->ProcessHandle = ProcessInformation.hProcess;
+ NewAction->SandboxPath = std::move(Prepared->SandboxPath);
+
+ {
+ RwLock::ExclusiveLockScope _(m_RunningLock);
+
+ m_RunningMap[Prepared->ActionLsn] = std::move(NewAction);
+ }
+
+ Action->SetActionState(RunnerAction::State::Running);
+
+ return SubmitResult{.IsAccepted = true};
+}
+
+// Scans the running map for child processes that have exited, moves them out
+// under the exclusive lock, then hands them to ProcessCompletedActions
+// outside the lock.
+void
+WindowsProcessRunner::SweepRunningActions()
+{
+ ZEN_TRACE_CPU("WindowsProcessRunner::SweepRunningActions");
+ std::vector<Ref<RunningAction>> CompletedActions;
+
+ m_RunningLock.WithExclusiveLock([&] {
+ for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;)
+ {
+ Ref<RunningAction> Running = It->second;
+
+ DWORD ExitCode = 0;
+ BOOL IsSuccess = GetExitCodeProcess(Running->ProcessHandle, &ExitCode);
+
+ // STILL_ACTIVE means the process has not exited yet.
+ if (IsSuccess && ExitCode != STILL_ACTIVE)
+ {
+ CloseHandle(Running->ProcessHandle);
+ Running->ProcessHandle = INVALID_HANDLE_VALUE;
+ Running->ExitCode = ExitCode;
+
+ CompletedActions.push_back(std::move(Running));
+ It = m_RunningMap.erase(It);
+ }
+ else
+ {
+ ++It;
+ }
+ }
+ });
+
+ // Runs without the lock held to avoid deadlocks with callbacks.
+ ProcessCompletedActions(CompletedActions);
+}
+
+// Terminates every running child process, waits briefly for each to exit,
+// queues its sandbox directory for deferred deletion, and marks the action
+// Failed. The running map is swapped out under the lock so termination and
+// waiting happen without holding it.
+void
+WindowsProcessRunner::CancelRunningActions()
+{
+ ZEN_TRACE_CPU("WindowsProcessRunner::CancelRunningActions");
+ Stopwatch Timer;
+ std::unordered_map<int, Ref<RunningAction>> RunningMap;
+
+ m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); });
+
+ if (RunningMap.empty())
+ {
+ return;
+ }
+
+ ZEN_INFO("cancelling all running actions");
+
+ // For expedience, initiate termination for all known processes before
+ // waiting for any of them to exit.
+
+ for (const auto& Kv : RunningMap)
+ {
+ Ref<RunningAction> Running = Kv.second;
+
+ BOOL TermSuccess = TerminateProcess(Running->ProcessHandle, 222);
+
+ if (!TermSuccess)
+ {
+ DWORD LastError = GetLastError();
+
+ // ACCESS_DENIED commonly means the process already exited; stay quiet.
+ if (LastError != ERROR_ACCESS_DENIED)
+ {
+ ZEN_WARN("TerminateProcess for LSN {} not successful: {}", Running->Action->ActionLsn, GetSystemErrorAsString(LastError));
+ }
+ }
+ }
+
+ // Wait for all processes and clean up, regardless of whether TerminateProcess succeeded.
+
+ for (auto& [Lsn, Running] : RunningMap)
+ {
+ if (Running->ProcessHandle != INVALID_HANDLE_VALUE)
+ {
+ // Bounded wait (2s) per process so shutdown cannot hang indefinitely.
+ DWORD WaitResult = WaitForSingleObject(Running->ProcessHandle, 2000);
+
+ if (WaitResult != WAIT_OBJECT_0)
+ {
+ ZEN_WARN("wait for LSN {}: process exit did not succeed, result = {}", Running->Action->ActionLsn, WaitResult);
+ }
+ else
+ {
+ ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn);
+ }
+
+ CloseHandle(Running->ProcessHandle);
+ Running->ProcessHandle = INVALID_HANDLE_VALUE;
+ }
+
+ m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath));
+ Running->Action->SetActionState(RunnerAction::State::Failed);
+ }
+
+ ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+}
+
+// Requests termination of a single running action by its LSN. Returns true
+// when the terminate call was successfully issued; the monitor thread later
+// observes the exit and marks the action Failed.
+bool
+WindowsProcessRunner::CancelAction(int ActionLsn)
+{
+ ZEN_TRACE_CPU("WindowsProcessRunner::CancelAction");
+
+ // Hold the shared lock while terminating to prevent the sweep thread from
+ // closing the handle between our lookup and TerminateProcess call.
+ bool Sent = false;
+
+ m_RunningLock.WithSharedLock([&] {
+ auto It = m_RunningMap.find(ActionLsn);
+ if (It == m_RunningMap.end())
+ {
+ return;
+ }
+
+ Ref<RunningAction> Target = It->second;
+ if (Target->ProcessHandle == INVALID_HANDLE_VALUE)
+ {
+ // Process already reaped; nothing to terminate.
+ return;
+ }
+
+ BOOL TermSuccess = TerminateProcess(Target->ProcessHandle, 222);
+
+ if (!TermSuccess)
+ {
+ DWORD LastError = GetLastError();
+
+ if (LastError != ERROR_ACCESS_DENIED)
+ {
+ ZEN_WARN("CancelAction: TerminateProcess for LSN {} not successful: {}", ActionLsn, GetSystemErrorAsString(LastError));
+ }
+
+ return;
+ }
+
+ ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn);
+ Sent = true;
+ });
+
+ // The monitor thread will pick up the process exit and mark the action as Failed.
+ return Sent;
+}
+
+// Samples kernel+user CPU time for one running child and publishes cumulative
+// CPU seconds plus (from the second sample on) a delta-based usage percentage.
+// Per-process totals only; an exited/invalid handle makes GetProcessTimes
+// fail and the sample is skipped.
+void
+WindowsProcessRunner::SampleProcessCpu(RunningAction& Running)
+{
+ FILETIME CreationTime, ExitTime, KernelTime, UserTime;
+ if (!GetProcessTimes(Running.ProcessHandle, &CreationTime, &ExitTime, &KernelTime, &UserTime))
+ {
+ return;
+ }
+
+ auto FtToU64 = [](FILETIME Ft) -> uint64_t { return (static_cast<uint64_t>(Ft.dwHighDateTime) << 32) | Ft.dwLowDateTime; };
+
+ // FILETIME values are in 100-nanosecond intervals
+ const uint64_t CurrentOsTicks = FtToU64(KernelTime) + FtToU64(UserTime);
+ const uint64_t NowTicks = GetHifreqTimerValue();
+
+ // Cumulative CPU seconds (absolute, available from first sample): 100ns → seconds
+ Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 10'000'000.0), std::memory_order_relaxed);
+
+ // Percentage needs a previous sample; both tick fields are 0 on first call.
+ if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0)
+ {
+ const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks);
+ if (ElapsedMs > 0)
+ {
+ const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks;
+ // 100ns → ms: divide by 10000; then as percent of elapsed ms
+ const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 10000.0 / ElapsedMs * 100.0);
+ Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed);
+ }
+ }
+
+ Running.LastCpuSampleTicks = NowTicks;
+ Running.LastCpuOsTicks = CurrentOsTicks;
+}
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/windowsrunner.h b/src/zencompute/runners/windowsrunner.h
new file mode 100644
index 000000000..9f2385cc4
--- /dev/null
+++ b/src/zencompute/runners/windowsrunner.h
@@ -0,0 +1,53 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "localrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_WINDOWS
+
+# include <zencore/windows.h>
+
+# include <string>
+
+namespace zen::compute {
+
+/** Windows process runner using CreateProcessW for executing worker executables.
+
+ Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting,
+ input/output handling, and monitor thread infrastructure. Overrides only the
+ platform-specific methods: process spawning, sweep, and cancellation.
+
+ When Sandboxed is true, child processes are isolated using a Windows AppContainer:
+ no network access (AppContainer blocks network by default when no capabilities are
+ granted) and no filesystem access outside explicitly granted sandbox and worker
+ directories. This requires no elevation.
+ */
+class WindowsProcessRunner : public LocalProcessRunner
+{
+public:
+    WindowsProcessRunner(ChunkResolver& Resolver,
+                         const std::filesystem::path& BaseDir,
+                         DeferredDirectoryDeleter& Deleter,
+                         WorkerThreadPool& WorkerPool,
+                         bool Sandboxed = false,
+                         int32_t MaxConcurrentActions = 0);
+    ~WindowsProcessRunner();
+
+    [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
+    void SweepRunningActions() override;
+    void CancelRunningActions() override;
+    bool CancelAction(int ActionLsn) override;
+    void SampleProcessCpu(RunningAction& Running) override;
+
+private:
+    /** Grant the AppContainer access rights on Path (sandboxed mode only). */
+    void GrantAppContainerAccess(const std::filesystem::path& Path, DWORD AccessMask);
+
+    bool m_Sandboxed = false; // When true, children run inside an AppContainer.
+    PSID m_AppContainerSid = nullptr; // AppContainer SID; null when not sandboxed. TODO confirm freed in dtor.
+    std::wstring m_AppContainerName; // Name of the AppContainer profile.
+};
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp
new file mode 100644
index 000000000..506bec73b
--- /dev/null
+++ b/src/zencompute/runners/winerunner.cpp
@@ -0,0 +1,237 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "winerunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/except.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/iobuffer.h>
+# include <zencore/iohash.h>
+# include <zencore/timer.h>
+# include <zencore/trace.h>
+
+# include <signal.h>
+# include <sys/wait.h>
+# include <unistd.h>
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver,
+                                     const std::filesystem::path& BaseDir,
+                                     DeferredDirectoryDeleter& Deleter,
+                                     WorkerThreadPool& WorkerPool)
+: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool)
+{
+    // Restore SIGCHLD to default behavior so waitpid() can properly collect
+    // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which
+    // causes the kernel to auto-reap children, making waitpid() return
+    // -1/ECHILD instead of the exit status we need.
+    // NOTE(review): this is process-global and affects every other child the
+    // server spawns, not just Wine workers — confirm that is acceptable.
+    struct sigaction Action = {};
+    sigemptyset(&Action.sa_mask);
+    Action.sa_handler = SIG_DFL;
+    sigaction(SIGCHLD, &Action, nullptr);
+}
+
+SubmitResult
+WineProcessRunner::SubmitAction(Ref<RunnerAction> Action)
+{
+    // Prepare the sandbox/worker for the action, then fork+execve Wine with
+    // the worker executable. Returns IsAccepted=false when preparation fails.
+    ZEN_TRACE_CPU("WineProcessRunner::SubmitAction");
+    std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action);
+
+    if (!Prepared)
+    {
+        return SubmitResult{.IsAccepted = false};
+    }
+
+    // Build environment array from worker descriptor
+
+    CbObject WorkerDescription = Prepared->WorkerPackage.GetObject();
+
+    // NOTE(review): the child receives ONLY the manifest environment, not the
+    // parent's — confirm Wine can find HOME/WINEPREFIX/DISPLAY this way.
+    std::vector<std::string> EnvStrings;
+    for (auto& It : WorkerDescription["environment"sv])
+    {
+        EnvStrings.emplace_back(It.AsString());
+    }
+
+    std::vector<char*> Envp;
+    Envp.reserve(EnvStrings.size() + 1);
+    for (auto& Str : EnvStrings)
+    {
+        Envp.push_back(Str.data());
+    }
+    Envp.push_back(nullptr);
+
+    // Build argv: wine <worker_exe_path> -Build=build.action
+
+    std::string_view ExecPath = WorkerDescription["path"sv].AsString();
+    std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath);
+    std::string ExePathStr = ExePath.string();
+    std::string WinePathStr = m_WinePath;
+    std::string BuildArg = "-Build=build.action";
+
+    // The char* vectors point into the strings above, which stay alive through
+    // fork()'s address-space copy until execve() in the child.
+    std::vector<char*> ArgV;
+    ArgV.push_back(WinePathStr.data());
+    ArgV.push_back(ExePathStr.data());
+    ArgV.push_back(BuildArg.data());
+    ArgV.push_back(nullptr);
+
+    ZEN_DEBUG("Executing via Wine: {} {} {}", WinePathStr, ExePathStr, BuildArg);
+
+    std::string SandboxPathStr = Prepared->SandboxPath.string();
+
+    pid_t ChildPid = fork();
+
+    if (ChildPid < 0)
+    {
+        throw std::runtime_error(fmt::format("fork() failed: {}", strerror(errno)));
+    }
+
+    if (ChildPid == 0)
+    {
+        // Child process
+        if (chdir(SandboxPathStr.c_str()) != 0)
+        {
+            _exit(127);
+        }
+
+        // NOTE(review): execve() does NOT search PATH, and we just chdir()'d
+        // into the sandbox — the default m_WinePath of "wine" resolves
+        // relative to the sandbox and will likely fail with 127 unless an
+        // absolute path is configured. Consider execvpe() or resolving the
+        // path before fork().
+        execve(WinePathStr.c_str(), ArgV.data(), Envp.data());
+
+        // execve only returns on failure
+        _exit(127);
+    }
+
+    // Parent: store child pid as void* (same convention as zencore/process.cpp)
+
+    Ref<RunningAction> NewAction{new RunningAction()};
+    NewAction->Action = Action;
+    NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid));
+    NewAction->SandboxPath = std::move(Prepared->SandboxPath);
+
+    {
+        RwLock::ExclusiveLockScope _(m_RunningLock);
+        m_RunningMap[Prepared->ActionLsn] = std::move(NewAction);
+    }
+
+    Action->SetActionState(RunnerAction::State::Running);
+
+    return SubmitResult{.IsAccepted = true};
+}
+
+void
+WineProcessRunner::SweepRunningActions()
+{
+    // Poll every tracked child with waitpid(WNOHANG); children that have
+    // exited (or can no longer be waited on) are moved out of m_RunningMap
+    // and handed to ProcessCompletedActions() outside the lock.
+    ZEN_TRACE_CPU("WineProcessRunner::SweepRunningActions");
+    std::vector<Ref<RunningAction>> CompletedActions;
+
+    m_RunningLock.WithExclusiveLock([&] {
+        for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;)
+        {
+            Ref<RunningAction> Running = It->second;
+
+            pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle));
+            int Status = 0;
+
+            pid_t Result = waitpid(Pid, &Status, WNOHANG);
+
+            if (Result == 0)
+            {
+                // Still running
+                ++It;
+                continue;
+            }
+
+            if (Result == Pid)
+            {
+                if (WIFEXITED(Status))
+                {
+                    Running->ExitCode = WEXITSTATUS(Status);
+                }
+                else if (WIFSIGNALED(Status))
+                {
+                    // Shell convention: 128 + signal number
+                    Running->ExitCode = 128 + WTERMSIG(Status);
+                }
+                else
+                {
+                    Running->ExitCode = 1;
+                }
+            }
+            else
+            {
+                // waitpid failed (e.g. -1/ECHILD if the child was reaped
+                // elsewhere). Previously such entries stayed in m_RunningMap
+                // forever; treat them as failed completions instead.
+                ZEN_WARN("waitpid for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno));
+                Running->ExitCode = 1;
+            }
+
+            Running->ProcessHandle = nullptr;
+
+            CompletedActions.push_back(std::move(Running));
+            It = m_RunningMap.erase(It);
+        }
+    });
+
+    ProcessCompletedActions(CompletedActions);
+}
+
+void
+WineProcessRunner::CancelRunningActions()
+{
+    // Shutdown path: detach the whole running set under the lock (so no other
+    // thread observes half-cancelled state), then terminate every child with
+    // SIGTERM, escalating to SIGKILL after a 2-second grace period.
+    ZEN_TRACE_CPU("WineProcessRunner::CancelRunningActions");
+    Stopwatch Timer;
+    std::unordered_map<int, Ref<RunningAction>> RunningMap;
+
+    m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); });
+
+    if (RunningMap.empty())
+    {
+        return;
+    }
+
+    ZEN_INFO("cancelling all running actions");
+
+    // Send SIGTERM to all running processes first
+
+    for (const auto& [Lsn, Running] : RunningMap)
+    {
+        pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle));
+
+        if (kill(Pid, SIGTERM) != 0)
+        {
+            ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno));
+        }
+    }
+
+    // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up.
+
+    for (auto& [Lsn, Running] : RunningMap)
+    {
+        pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle));
+
+        // Poll for up to 2 seconds
+        bool Exited = false;
+        for (int i = 0; i < 20; ++i)
+        {
+            int Status = 0;
+            pid_t WaitResult = waitpid(Pid, &Status, WNOHANG);
+            if (WaitResult == Pid)
+            {
+                Exited = true;
+                ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn);
+                break;
+            }
+            usleep(100000); // 100ms
+        }
+
+        if (!Exited)
+        {
+            // Last resort: SIGKILL cannot be ignored; the blocking waitpid
+            // then reaps the zombie before we discard the sandbox.
+            ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn);
+            kill(Pid, SIGKILL);
+            waitpid(Pid, nullptr, 0);
+        }
+
+        // Sandbox removal is deferred; the action is reported as Failed since
+        // it never produced outputs.
+        m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath));
+        Running->Action->SetActionState(RunnerAction::State::Failed);
+    }
+
+    ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+}
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/winerunner.h b/src/zencompute/runners/winerunner.h
new file mode 100644
index 000000000..7df62e7c0
--- /dev/null
+++ b/src/zencompute/runners/winerunner.h
@@ -0,0 +1,37 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "localrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX
+
+# include <string>
+
+namespace zen::compute {
+
+/** Wine-based process runner for executing Windows worker executables on Linux.
+
+ Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting,
+ input/output handling, and monitor thread infrastructure. Overrides only the
+ platform-specific methods: process spawning, sweep, and cancellation.
+ */
+class WineProcessRunner : public LocalProcessRunner
+{
+public:
+    WineProcessRunner(ChunkResolver& Resolver,
+                      const std::filesystem::path& BaseDir,
+                      DeferredDirectoryDeleter& Deleter,
+                      WorkerThreadPool& WorkerPool);
+
+    [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
+    void SweepRunningActions() override;
+    void CancelRunningActions() override;
+
+private:
+    // Path of the Wine binary used to host Windows workers. NOTE(review):
+    // passed to execve(), which does not search PATH — the bare "wine"
+    // default only works if it resolves from the child's working directory.
+    std::string m_WinePath = "wine";
+};
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/testing/mockimds.cpp b/src/zencompute/testing/mockimds.cpp
new file mode 100644
index 000000000..dd09312df
--- /dev/null
+++ b/src/zencompute/testing/mockimds.cpp
@@ -0,0 +1,205 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencompute/mockimds.h>
+
+#include <zencore/fmtutils.h>
+
+#if ZEN_WITH_TESTS
+
+namespace zen::compute {
+
+const char*
+MockImdsService::BaseUri() const
+{
+    // The mock serves every provider's namespace from the server root.
+    static constexpr const char* kRootUri = "/";
+    return kRootUri;
+}
+
+void
+MockImdsService::HandleRequest(HttpServerRequest& Request)
+{
+    const std::string_view Uri = Request.RelativeUri();
+
+    // Dispatch on the provider-specific URI namespace. A prefix that belongs
+    // to a provider other than the active one yields 404, exactly like
+    // querying the wrong cloud's metadata endpoint on a real instance.
+    if (Uri.starts_with("latest/") && ActiveProvider == CloudProvider::AWS)
+    {
+        HandleAwsRequest(Request);
+        return;
+    }
+
+    if (Uri.starts_with("metadata/") && ActiveProvider == CloudProvider::Azure)
+    {
+        HandleAzureRequest(Request);
+        return;
+    }
+
+    if (Uri.starts_with("computeMetadata/") && ActiveProvider == CloudProvider::GCP)
+    {
+        HandleGcpRequest(Request);
+        return;
+    }
+
+    Request.WriteResponse(HttpResponseCode::NotFound);
+}
+
+// ---------------------------------------------------------------------------
+// AWS
+// ---------------------------------------------------------------------------
+
+void
+MockImdsService::HandleAwsRequest(HttpServerRequest& Request)
+{
+    const std::string_view Uri = Request.RelativeUri();
+
+    // Answer 200 with a plain-text body.
+    auto ReplyText = [&](const auto& Body) { Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Body); };
+
+    // IMDSv2 token acquisition is PUT-only; other verbs fall through to 404.
+    if (Uri == "latest/api/token" && Request.RequestVerb() == HttpVerb::kPut)
+    {
+        ReplyText(Aws.Token);
+        return;
+    }
+
+    if (Uri == "latest/meta-data/instance-id")
+    {
+        ReplyText(Aws.InstanceId);
+        return;
+    }
+
+    if (Uri == "latest/meta-data/placement/availability-zone")
+    {
+        ReplyText(Aws.AvailabilityZone);
+        return;
+    }
+
+    if (Uri == "latest/meta-data/instance-life-cycle")
+    {
+        ReplyText(Aws.LifeCycle);
+        return;
+    }
+
+    // ASG lifecycle state: an empty value models "not in an ASG" (404).
+    if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state")
+    {
+        if (Aws.AutoscalingState.empty())
+        {
+            Request.WriteResponse(HttpResponseCode::NotFound);
+        }
+        else
+        {
+            ReplyText(Aws.AutoscalingState);
+        }
+        return;
+    }
+
+    // Spot interruption notice: an empty value models "no interruption" (404).
+    if (Uri == "latest/meta-data/spot/instance-action")
+    {
+        if (Aws.SpotAction.empty())
+        {
+            Request.WriteResponse(HttpResponseCode::NotFound);
+        }
+        else
+        {
+            ReplyText(Aws.SpotAction);
+        }
+        return;
+    }
+
+    Request.WriteResponse(HttpResponseCode::NotFound);
+}
+
+// ---------------------------------------------------------------------------
+// Azure
+// ---------------------------------------------------------------------------
+
+void
+MockImdsService::HandleAzureRequest(HttpServerRequest& Request)
+{
+    const std::string_view Uri = Request.RelativeUri();
+
+    // Instance metadata: a single JSON document describing the VM.
+    if (Uri == "metadata/instance")
+    {
+        const std::string Body = fmt::format(R"({{"compute":{{"vmId":"{}","location":"{}","priority":"{}","vmScaleSetName":"{}"}}}})",
+                                             Azure.VmId,
+                                             Azure.Location,
+                                             Azure.Priority,
+                                             Azure.VmScaleSetName);
+        Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Body);
+        return;
+    }
+
+    // Scheduled-events endpoint used for termination monitoring; an empty
+    // configured event type models "no pending events".
+    if (Uri == "metadata/scheduledevents")
+    {
+        const std::string Body = Azure.ScheduledEventType.empty()
+                                     ? std::string{R"({"Events":[]})"}
+                                     : fmt::format(R"({{"Events":[{{"EventType":"{}","EventStatus":"{}"}}]}})",
+                                                   Azure.ScheduledEventType,
+                                                   Azure.ScheduledEventStatus);
+        Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Body);
+        return;
+    }
+
+    Request.WriteResponse(HttpResponseCode::NotFound);
+}
+
+// ---------------------------------------------------------------------------
+// GCP
+// ---------------------------------------------------------------------------
+
+void
+MockImdsService::HandleGcpRequest(HttpServerRequest& Request)
+{
+    const std::string_view Uri = Request.RelativeUri();
+
+    // Every GCP endpoint is a simple text lookup: map the URI to the
+    // configured value and answer 200, otherwise 404.
+    auto ReplyText = [&](const auto& Body) { Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Body); };
+
+    if (Uri == "computeMetadata/v1/instance/id")
+    {
+        ReplyText(Gcp.InstanceId);
+    }
+    else if (Uri == "computeMetadata/v1/instance/zone")
+    {
+        ReplyText(Gcp.Zone);
+    }
+    else if (Uri == "computeMetadata/v1/instance/scheduling/preemptible")
+    {
+        ReplyText(Gcp.Preemptible);
+    }
+    else if (Uri == "computeMetadata/v1/instance/maintenance-event")
+    {
+        ReplyText(Gcp.MaintenanceEvent);
+    }
+    else
+    {
+        Request.WriteResponse(HttpResponseCode::NotFound);
+    }
+}
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_TESTS
diff --git a/src/zencompute/timeline/workertimeline.cpp b/src/zencompute/timeline/workertimeline.cpp
new file mode 100644
index 000000000..88ef5b62d
--- /dev/null
+++ b/src/zencompute/timeline/workertimeline.cpp
@@ -0,0 +1,430 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "workertimeline.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/basicfile.h>
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinaryfile.h>
+
+# include <algorithm>
+
+namespace zen::compute {
+
+// Construct an empty timeline for the given worker; events are added via the
+// Record* methods and loaded from disk via ReadFrom().
+WorkerTimeline::WorkerTimeline(std::string_view WorkerId) : m_WorkerId(WorkerId)
+{
+}
+
+// Nothing to release beyond the members' own destructors; defined out of line
+// to match the header declaration.
+WorkerTimeline::~WorkerTimeline()
+{
+}
+
+void
+WorkerTimeline::RecordProvisioned()
+{
+    // Worker lifecycle: mark the moment this worker became available.
+    Event Evt;
+    Evt.Type      = EventType::WorkerProvisioned;
+    Evt.Timestamp = DateTime::Now();
+    AppendEvent(std::move(Evt));
+}
+
+void
+WorkerTimeline::RecordDeprovisioned()
+{
+    // Worker lifecycle: mark the moment this worker was taken out of service.
+    Event Evt;
+    Evt.Type      = EventType::WorkerDeprovisioned;
+    Evt.Timestamp = DateTime::Now();
+    AppendEvent(std::move(Evt));
+}
+
+void
+WorkerTimeline::RecordActionAccepted(int ActionLsn, const IoHash& ActionId)
+{
+    // Action lifecycle: the worker agreed to run this action.
+    Event Evt;
+    Evt.Type      = EventType::ActionAccepted;
+    Evt.Timestamp = DateTime::Now();
+    Evt.ActionLsn = ActionLsn;
+    Evt.ActionId  = ActionId;
+    AppendEvent(std::move(Evt));
+}
+
+void
+WorkerTimeline::RecordActionRejected(int ActionLsn, const IoHash& ActionId, std::string_view Reason)
+{
+    // Action lifecycle: the worker declined this action; Reason is persisted.
+    Event Evt;
+    Evt.Type      = EventType::ActionRejected;
+    Evt.Timestamp = DateTime::Now();
+    Evt.ActionLsn = ActionLsn;
+    Evt.ActionId  = ActionId;
+    Evt.Reason    = std::string(Reason);
+    AppendEvent(std::move(Evt));
+}
+
+void
+WorkerTimeline::RecordActionStateChanged(int ActionLsn,
+                                         const IoHash& ActionId,
+                                         RunnerAction::State PreviousState,
+                                         RunnerAction::State NewState)
+{
+    // Action lifecycle: record a state transition, keeping both endpoints so
+    // the dashboard can reconstruct the action's path through the runner.
+    Event Evt;
+    Evt.Type          = EventType::ActionStateChanged;
+    Evt.Timestamp     = DateTime::Now();
+    Evt.ActionLsn     = ActionLsn;
+    Evt.ActionId      = ActionId;
+    Evt.ActionState   = NewState;
+    Evt.PreviousState = PreviousState;
+    AppendEvent(std::move(Evt));
+}
+
+std::vector<WorkerTimeline::Event>
+WorkerTimeline::QueryTimeline(DateTime StartTime, DateTime EndTime) const
+{
+    // Collect all events whose timestamp falls inside [StartTime, EndTime].
+    // Events are stored in append order, so the result is already ordered.
+    std::vector<Event> Matched;
+
+    m_EventsLock.WithSharedLock([&] {
+        for (const Event& Candidate : m_Events)
+        {
+            const bool InRange = Candidate.Timestamp >= StartTime && Candidate.Timestamp <= EndTime;
+            if (InRange)
+            {
+                Matched.push_back(Candidate);
+            }
+        }
+    });
+
+    return Matched;
+}
+
+std::vector<WorkerTimeline::Event>
+WorkerTimeline::QueryRecent(int Limit) const
+{
+    // Return the Limit most recent events (in append order).
+    std::vector<Event> Result;
+
+    m_EventsLock.WithSharedLock([&] {
+        // Clamp to [0, size]: a negative Limit would otherwise yield a
+        // negative count and an iterator before begin() — undefined behavior.
+        const int Count = std::clamp(Limit, 0, gsl::narrow<int>(m_Events.size()));
+        auto It = m_Events.end() - Count;
+        Result.assign(It, m_Events.end());
+    });
+
+    return Result;
+}
+
+size_t
+WorkerTimeline::GetEventCount() const
+{
+    // Snapshot the event-queue size under a shared lock.
+    size_t Total = 0;
+    m_EventsLock.WithSharedLock([&] { Total = m_Events.size(); });
+    return Total;
+}
+
+WorkerTimeline::TimeRange
+WorkerTimeline::GetTimeRange() const
+{
+    // First/last timestamps are well-defined because events are appended in
+    // chronological order; an empty timeline returns the default (zero) range.
+    TimeRange Span;
+    m_EventsLock.WithSharedLock([&] {
+        if (m_Events.empty())
+        {
+            return;
+        }
+        Span.First = m_Events.front().Timestamp;
+        Span.Last  = m_Events.back().Timestamp;
+    });
+    return Span;
+}
+
+void
+WorkerTimeline::AppendEvent(Event&& Evt)
+{
+    // Append under the exclusive lock, evicting the oldest events first so the
+    // in-memory history stays bounded at m_MaxEvents entries.
+    m_EventsLock.WithExclusiveLock([&] {
+        while (m_Events.size() >= m_MaxEvents)
+        {
+            m_Events.pop_front();
+        }
+
+        m_Events.push_back(std::move(Evt));
+    });
+}
+
+const char*
+WorkerTimeline::ToString(EventType Type)
+{
+    // Stable wire names used by the persisted timeline format; must stay in
+    // sync with EventTypeFromString() in this file.
+    if (Type == EventType::WorkerProvisioned)
+    {
+        return "provisioned";
+    }
+    if (Type == EventType::WorkerDeprovisioned)
+    {
+        return "deprovisioned";
+    }
+    if (Type == EventType::ActionAccepted)
+    {
+        return "accepted";
+    }
+    if (Type == EventType::ActionRejected)
+    {
+        return "rejected";
+    }
+    if (Type == EventType::ActionStateChanged)
+    {
+        return "state_changed";
+    }
+    return "unknown";
+}
+
+static WorkerTimeline::EventType
+EventTypeFromString(std::string_view Str)
+{
+    // Inverse of WorkerTimeline::ToString(); unknown strings fall back to
+    // WorkerProvisioned, matching the previous behavior.
+    using EventType = WorkerTimeline::EventType;
+
+    struct Mapping
+    {
+        std::string_view Name;
+        EventType        Type;
+    };
+    static constexpr Mapping kMappings[] = {
+        {"provisioned", EventType::WorkerProvisioned},
+        {"deprovisioned", EventType::WorkerDeprovisioned},
+        {"accepted", EventType::ActionAccepted},
+        {"rejected", EventType::ActionRejected},
+        {"state_changed", EventType::ActionStateChanged},
+    };
+
+    for (const Mapping& M : kMappings)
+    {
+        if (M.Name == Str)
+        {
+            return M.Type;
+        }
+    }
+    return EventType::WorkerProvisioned;
+}
+
+void
+WorkerTimeline::WriteTo(const std::filesystem::path& Path) const
+{
+    // Serialize the timeline as one compact-binary object: worker_id, the
+    // optional time_first/time_last summary, then the full event array.
+    // ReadFrom() and ReadTimeRange() must stay in sync with this layout.
+    CbObjectWriter Cbo;
+    Cbo << "worker_id" << m_WorkerId;
+
+    m_EventsLock.WithSharedLock([&] {
+        if (!m_Events.empty())
+        {
+            // Summary fields let ReadTimeRange() skip decoding the event array.
+            Cbo.AddDateTime("time_first", m_Events.front().Timestamp);
+            Cbo.AddDateTime("time_last", m_Events.back().Timestamp);
+        }
+
+        Cbo.BeginArray("events");
+        for (const auto& Evt : m_Events)
+        {
+            Cbo.BeginObject();
+            Cbo << "type" << ToString(Evt.Type);
+            Cbo.AddDateTime("ts", Evt.Timestamp);
+
+            // Lsn 0 marks worker-lifecycle events, which carry no action context.
+            if (Evt.ActionLsn != 0)
+            {
+                Cbo << "lsn" << Evt.ActionLsn;
+                Cbo << "action_id" << Evt.ActionId;
+            }
+
+            if (Evt.Type == EventType::ActionStateChanged)
+            {
+                Cbo << "prev_state" << static_cast<int32_t>(Evt.PreviousState);
+                Cbo << "state" << static_cast<int32_t>(Evt.ActionState);
+            }
+
+            if (!Evt.Reason.empty())
+            {
+                Cbo << "reason" << std::string_view(Evt.Reason);
+            }
+
+            Cbo.EndObject();
+        }
+        Cbo.EndArray();
+    });
+
+    CbObject Obj = Cbo.Save();
+
+    // Truncate-and-write: the file always holds a complete, fresh object.
+    BasicFile File(Path, BasicFile::Mode::kTruncate);
+    File.Write(Obj.GetBuffer().GetView(), 0);
+}
+
+void
+WorkerTimeline::ReadFrom(const std::filesystem::path& Path)
+{
+    // Load a timeline persisted by WriteTo(), replacing the in-memory events.
+    // An unreadable/empty object leaves the current events untouched.
+    CbObjectFromFile Loaded = LoadCompactBinaryObject(Path);
+    CbObject Root = std::move(Loaded.Object);
+
+    if (!Root)
+    {
+        return;
+    }
+
+    // Decode outside the lock; only the final swap below needs exclusivity.
+    std::deque<Event> LoadedEvents;
+
+    for (CbFieldView Field : Root["events"].AsArrayView())
+    {
+        CbObjectView EventObj = Field.AsObjectView();
+
+        Event Evt;
+        Evt.Type = EventTypeFromString(EventObj["type"].AsString());
+        Evt.Timestamp = EventObj["ts"].AsDateTime();
+
+        // Absent fields decode to defaults (0 / zero hash), mirroring the
+        // conditional writes in WriteTo().
+        Evt.ActionLsn = EventObj["lsn"].AsInt32();
+        Evt.ActionId = EventObj["action_id"].AsHash();
+
+        if (Evt.Type == EventType::ActionStateChanged)
+        {
+            Evt.PreviousState = static_cast<RunnerAction::State>(EventObj["prev_state"].AsInt32());
+            Evt.ActionState = static_cast<RunnerAction::State>(EventObj["state"].AsInt32());
+        }
+
+        std::string_view Reason = EventObj["reason"].AsString();
+        if (!Reason.empty())
+        {
+            Evt.Reason = std::string(Reason);
+        }
+
+        LoadedEvents.push_back(std::move(Evt));
+    }
+
+    m_EventsLock.WithExclusiveLock([&] { m_Events = std::move(LoadedEvents); });
+}
+
+WorkerTimeline::TimeRange
+WorkerTimeline::ReadTimeRange(const std::filesystem::path& Path)
+{
+    // Decode only the summary fields written by WriteTo(); the event array is
+    // never touched, which keeps this cheap for directory scans.
+    CbObjectFromFile Loaded = LoadCompactBinaryObject(Path);
+
+    TimeRange Range;
+    if (Loaded.Object)
+    {
+        Range.First = Loaded.Object["time_first"].AsDateTime();
+        Range.Last  = Loaded.Object["time_last"].AsDateTime();
+    }
+    return Range;
+}
+
+// WorkerTimelineStore
+
+static constexpr std::string_view kTimelineExtension = ".ztimeline";
+
+WorkerTimelineStore::WorkerTimelineStore(std::filesystem::path PersistenceDir) : m_PersistenceDir(std::move(PersistenceDir))
+{
+    // Best-effort: persistence is optional, so a failure to create the
+    // directory is deliberately ignored (saves will simply fail later).
+    std::error_code Ec;
+    std::filesystem::create_directories(m_PersistenceDir, Ec);
+}
+
+Ref<WorkerTimeline>
+WorkerTimelineStore::GetOrCreate(std::string_view WorkerId)
+{
+    // Fast path: check if it already exists in memory
+    {
+        RwLock::SharedLockScope _(m_Lock);
+        auto It = m_Timelines.find(std::string(WorkerId));
+        if (It != m_Timelines.end())
+        {
+            return It->second;
+        }
+    }
+
+    // Slow path: create under exclusive lock, loading from disk if available.
+    // Another thread may have created the entry between the two locks;
+    // operator[] plus the null check below makes that race benign.
+    RwLock::ExclusiveLockScope _(m_Lock);
+
+    auto& Entry = m_Timelines[std::string(WorkerId)];
+    if (!Entry)
+    {
+        Entry = Ref<WorkerTimeline>(new WorkerTimeline(WorkerId));
+
+        // First access for this worker: hydrate from a persisted file if any.
+        std::filesystem::path Path = TimelinePath(WorkerId);
+        std::error_code Ec;
+        if (std::filesystem::is_regular_file(Path, Ec))
+        {
+            Entry->ReadFrom(Path);
+        }
+    }
+    return Entry;
+}
+
+Ref<WorkerTimeline>
+WorkerTimelineStore::Find(std::string_view WorkerId)
+{
+    // Purely an in-memory lookup: never loads from disk and never creates.
+    RwLock::SharedLockScope _(m_Lock);
+    const auto It = m_Timelines.find(std::string(WorkerId));
+    return It == m_Timelines.end() ? Ref<WorkerTimeline>{} : It->second;
+}
+
+std::vector<std::string>
+WorkerTimelineStore::GetActiveWorkerIds() const
+{
+    // Return the IDs of the timelines currently resident in memory.
+    std::vector<std::string> Result;
+
+    // The lock guard was previously named '$', which is not a valid
+    // identifier in standard C++ (it only compiles as a vendor extension);
+    // renamed to '_' to match every other guard in this file.
+    RwLock::SharedLockScope _(m_Lock);
+    Result.reserve(m_Timelines.size());
+    for (const auto& Entry : m_Timelines)
+    {
+        Result.push_back(Entry.first);
+    }
+
+    return Result;
+}
+
+std::vector<WorkerTimelineStore::WorkerTimelineInfo>
+WorkerTimelineStore::GetAllWorkerInfo() const
+{
+    // Merge in-memory timelines with persisted-only ones found on disk.
+    // In-memory entries win: their ranges reflect events not yet saved.
+    std::unordered_map<std::string, WorkerTimeline::TimeRange> InfoMap;
+
+    {
+        RwLock::SharedLockScope _(m_Lock);
+        for (const auto& [Id, Timeline] : m_Timelines)
+        {
+            InfoMap[Id] = Timeline->GetTimeRange();
+        }
+    }
+
+    // A missing/unreadable directory yields an empty iterator via Ec.
+    std::error_code Ec;
+    for (const auto& Entry : std::filesystem::directory_iterator(m_PersistenceDir, Ec))
+    {
+        // NOTE(review): this overload of is_regular_file() can throw on stat
+        // failure — consider Entry.is_regular_file(Ec) for consistency.
+        if (!Entry.is_regular_file())
+        {
+            continue;
+        }
+
+        const auto& Path = Entry.path();
+        if (Path.extension().string() != kTimelineExtension)
+        {
+            continue;
+        }
+
+        // Files are named {WorkerId}.ztimeline — see TimelinePath().
+        std::string Id = Path.stem().string();
+        if (InfoMap.find(Id) == InfoMap.end())
+        {
+            // Persisted-only worker: read just the cheap summary range.
+            InfoMap[Id] = WorkerTimeline::ReadTimeRange(Path);
+        }
+    }
+
+    std::vector<WorkerTimelineInfo> Result;
+    Result.reserve(InfoMap.size());
+    for (auto& [Id, Range] : InfoMap)
+    {
+        Result.push_back({.WorkerId = std::move(Id), .Range = Range});
+    }
+    return Result;
+}
+
+void
+WorkerTimelineStore::Save(std::string_view WorkerId)
+{
+    // Persist a single worker's timeline, if loaded; unknown IDs are a no-op.
+    RwLock::SharedLockScope _(m_Lock);
+    if (const auto It = m_Timelines.find(std::string(WorkerId)); It != m_Timelines.end())
+    {
+        It->second->WriteTo(TimelinePath(WorkerId));
+    }
+}
+
+void
+WorkerTimelineStore::SaveAll()
+{
+    // Flush every loaded timeline to its backing file.
+    RwLock::SharedLockScope _(m_Lock);
+    for (const auto& [WorkerId, Timeline] : m_Timelines)
+    {
+        Timeline->WriteTo(TimelinePath(WorkerId));
+    }
+}
+
+std::filesystem::path
+WorkerTimelineStore::TimelinePath(std::string_view WorkerId) const
+{
+    // Layout: {PersistenceDir}/{WorkerId}.ztimeline
+    std::string FileName{WorkerId};
+    FileName += kTimelineExtension;
+    return m_PersistenceDir / FileName;
+}
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/timeline/workertimeline.h b/src/zencompute/timeline/workertimeline.h
new file mode 100644
index 000000000..87e19bc28
--- /dev/null
+++ b/src/zencompute/timeline/workertimeline.h
@@ -0,0 +1,169 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../runners/functionrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zenbase/refcount.h>
+# include <zencore/compactbinary.h>
+# include <zencore/iohash.h>
+# include <zencore/thread.h>
+# include <zencore/timer.h>
+
+# include <deque>
+# include <filesystem>
+# include <string>
+# include <string_view>
+# include <unordered_map>
+# include <vector>
+
+namespace zen::compute {
+
+struct RunnerAction;
+
+/** Worker activity timeline for tracking and visualizing worker activity over time.
+ *
+ * Records worker lifecycle events (provisioning/deprovisioning) and action lifecycle
+ * events (accept, reject, state changes) with timestamps, enabling time-range queries
+ * for dashboard visualization.
+ */
+class WorkerTimeline : public RefCounted
+{
+public:
+    explicit WorkerTimeline(std::string_view WorkerId);
+    ~WorkerTimeline() override;
+
+    /** Inclusive first/last event timestamps; zero ticks mean "empty". */
+    struct TimeRange
+    {
+        DateTime First = DateTime(0);
+        DateTime Last = DateTime(0);
+
+        explicit operator bool() const { return First.GetTicks() != 0; }
+    };
+
+    enum class EventType
+    {
+        WorkerProvisioned,
+        WorkerDeprovisioned,
+        ActionAccepted,
+        ActionRejected,
+        ActionStateChanged
+    };
+
+    /** Stable wire name for an event type (used in the persisted format). */
+    static const char* ToString(EventType Type);
+
+    struct Event
+    {
+        EventType Type;
+        DateTime Timestamp = DateTime(0);
+
+        // Action context (only set for action events)
+        int ActionLsn = 0;
+        IoHash ActionId;
+        RunnerAction::State ActionState = RunnerAction::State::New;
+        RunnerAction::State PreviousState = RunnerAction::State::New;
+
+        // Optional reason (e.g. rejection reason)
+        std::string Reason;
+    };
+
+    /** Record that this worker has been provisioned and is available for work. */
+    void RecordProvisioned();
+
+    /** Record that this worker has been deprovisioned and is no longer available. */
+    void RecordDeprovisioned();
+
+    /** Record that an action was accepted by this worker. */
+    void RecordActionAccepted(int ActionLsn, const IoHash& ActionId);
+
+    /** Record that an action was rejected by this worker. */
+    void RecordActionRejected(int ActionLsn, const IoHash& ActionId, std::string_view Reason);
+
+    /** Record an action state transition on this worker. */
+    void RecordActionStateChanged(int ActionLsn, const IoHash& ActionId, RunnerAction::State PreviousState, RunnerAction::State NewState);
+
+    /** Query events within a time range (inclusive). Returns events ordered by timestamp. */
+    [[nodiscard]] std::vector<Event> QueryTimeline(DateTime StartTime, DateTime EndTime) const;
+
+    /** Query the most recent N events. */
+    [[nodiscard]] std::vector<Event> QueryRecent(int Limit = 100) const;
+
+    /** Return the total number of recorded events. */
+    [[nodiscard]] size_t GetEventCount() const;
+
+    /** Return the time range covered by the events in this timeline. */
+    [[nodiscard]] TimeRange GetTimeRange() const;
+
+    [[nodiscard]] const std::string& GetWorkerId() const { return m_WorkerId; }
+
+    /** Write the timeline to a file at the given path. */
+    void WriteTo(const std::filesystem::path& Path) const;
+
+    /** Read the timeline from a file at the given path. Replaces current in-memory events. */
+    void ReadFrom(const std::filesystem::path& Path);
+
+    /** Read only the time range from a persisted timeline file, without loading events. */
+    [[nodiscard]] static TimeRange ReadTimeRange(const std::filesystem::path& Path);
+
+private:
+    /** Append one event, evicting the oldest when the m_MaxEvents cap is hit. */
+    void AppendEvent(Event&& Evt);
+
+    std::string m_WorkerId;
+    mutable RwLock m_EventsLock; // Guards m_Events; mutable so const queries can take shared locks.
+    std::deque<Event> m_Events; // Append-ordered ring of events, capped at m_MaxEvents.
+    size_t m_MaxEvents = 10'000; // In-memory history bound; oldest events are dropped beyond this.
+};
+
+/** Manages a set of WorkerTimeline instances, keyed by worker ID.
+ *
+ * Provides thread-safe lookup and on-demand creation of timelines, backed by
+ * a persistence directory. Each timeline is stored as a separate file named
+ * {WorkerId}.ztimeline within the directory.
+ */
+class WorkerTimelineStore
+{
+public:
+    explicit WorkerTimelineStore(std::filesystem::path PersistenceDir);
+    ~WorkerTimelineStore() = default;
+
+    // Non-copyable: owns the lock and the loaded-timeline cache.
+    WorkerTimelineStore(const WorkerTimelineStore&) = delete;
+    WorkerTimelineStore& operator=(const WorkerTimelineStore&) = delete;
+
+    /** Get the timeline for a worker, creating one if it does not exist.
+     * If a persisted file exists on disk it will be loaded on first access. */
+    Ref<WorkerTimeline> GetOrCreate(std::string_view WorkerId);
+
+    /** Get the timeline for a worker, or null ref if it does not exist in memory. */
+    [[nodiscard]] Ref<WorkerTimeline> Find(std::string_view WorkerId);
+
+    /** Return the worker IDs of currently loaded (in-memory) timelines. */
+    [[nodiscard]] std::vector<std::string> GetActiveWorkerIds() const;
+
+    struct WorkerTimelineInfo
+    {
+        std::string WorkerId;
+        WorkerTimeline::TimeRange Range;
+    };
+
+    /** Return info for all known timelines (in-memory and on-disk), including time range. */
+    [[nodiscard]] std::vector<WorkerTimelineInfo> GetAllWorkerInfo() const;
+
+    /** Persist a single worker's timeline to disk. */
+    void Save(std::string_view WorkerId);
+
+    /** Persist all in-memory timelines to disk. */
+    void SaveAll();
+
+private:
+    /** Path of a worker's persisted file: {PersistenceDir}/{WorkerId}.ztimeline */
+    [[nodiscard]] std::filesystem::path TimelinePath(std::string_view WorkerId) const;
+
+    std::filesystem::path m_PersistenceDir;
+    mutable RwLock m_Lock; // Guards m_Timelines (not the timelines themselves — they self-lock).
+    std::unordered_map<std::string, Ref<WorkerTimeline>> m_Timelines;
+};
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/xmake.lua b/src/zencompute/xmake.lua
new file mode 100644
index 000000000..ed0af66a5
--- /dev/null
+++ b/src/zencompute/xmake.lua
@@ -0,0 +1,19 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target('zencompute')
+    -- Static library bundling the compute-service runners, worker timeline
+    -- and cloud-metadata/mock-IMDS support.
+    set_kind("static")
+    set_group("libs")
+    add_headerfiles("**.h")
+    add_files("**.cpp")
+    add_includedirs("include", {public=true})
+    -- Private include root so intra-library headers resolve without prefixes.
+    add_includedirs(".", {private=true})
+    add_deps("zencore", "zenstore", "zenutil", "zennet", "zenhttp")
+    add_packages("json11")
+
+    if is_os("macosx") then
+        add_cxxflags("-Wno-deprecated-declarations")
+    end
+
+    if is_plat("windows") then
+        -- Userenv is required by the AppContainer profile APIs used by the
+        -- sandboxed Windows runner.
+        add_syslinks("Userenv")
+    end
diff --git a/src/zencompute/zencompute.cpp b/src/zencompute/zencompute.cpp
new file mode 100644
index 000000000..1f3f6d3f9
--- /dev/null
+++ b/src/zencompute/zencompute.cpp
@@ -0,0 +1,21 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencompute/zencompute.h"
+
+#if ZEN_WITH_TESTS
+# include "runners/deferreddeleter.h"
+# include <zencompute/cloudmetadata.h>
+#endif
+
+namespace zen {
+
+void
+zencompute_forcelinktests()
+{
+    // Referencing each module's forcelink symbol prevents the linker from
+    // discarding the translation units that register this library's tests.
+#if ZEN_WITH_TESTS
+    compute::cloudmetadata_forcelink();
+    compute::deferreddeleter_forcelink();
+#endif
+}
+} // namespace zen
diff --git a/src/zencore-test/zencore-test.cpp b/src/zencore-test/zencore-test.cpp
index 68fc940ee..3d9a79283 100644
--- a/src/zencore-test/zencore-test.cpp
+++ b/src/zencore-test/zencore-test.cpp
@@ -1,47 +1,15 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-// zencore-test.cpp : Defines the entry point for the console application.
-//
-
-#include <zencore/filesystem.h>
-#include <zencore/logging.h>
-#include <zencore/trace.h>
+#include <zencore/testing.h>
#include <zencore/zencore.h>
#include <zencore/memory/newdelete.h>
-#if ZEN_WITH_TESTS
-# define ZEN_TEST_WITH_RUNNER 1
-# include <zencore/testing.h>
-# include <zencore/process.h>
-#endif
-
int
main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
{
-#if ZEN_PLATFORM_WINDOWS
- setlocale(LC_ALL, "en_us.UTF8");
-#endif // ZEN_PLATFORM_WINDOWS
-
#if ZEN_WITH_TESTS
- zen::zencore_forcelinktests();
-
-# if ZEN_PLATFORM_LINUX
- zen::IgnoreChildSignals();
-# endif
-
-# if ZEN_WITH_TRACE
- zen::TraceInit("zencore-test");
- zen::TraceOptions TraceCommandlineOptions;
- if (GetTraceOptionsFromCommandline(TraceCommandlineOptions))
- {
- TraceConfigure(TraceCommandlineOptions);
- }
-# endif // ZEN_WITH_TRACE
- zen::logging::InitializeLogging();
- zen::MaximizeOpenFileCount();
-
- return ZEN_RUN_TESTS(argc, argv);
+ return zen::testing::RunTestMain(argc, argv, "zencore-test", zen::zencore_forcelinktests);
#else
return 0;
#endif
diff --git a/src/zencore/base64.cpp b/src/zencore/base64.cpp
index 1f56ee6c3..96e121799 100644
--- a/src/zencore/base64.cpp
+++ b/src/zencore/base64.cpp
@@ -1,6 +1,10 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include <zencore/base64.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+
+#include <string>
namespace zen {
@@ -11,7 +15,6 @@ static const uint8_t EncodingAlphabet[64] = {'A', 'B', 'C', 'D', 'E', 'F', 'G',
'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'};
/** The table used to convert an ascii character into a 6 bit value */
-#if 0
static const uint8_t DecodingAlphabet[256] = {
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x00-0x0f
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x10-0x1f
@@ -30,7 +33,6 @@ static const uint8_t DecodingAlphabet[256] = {
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xe0-0xef
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF // 0xf0-0xff
};
-#endif // 0
template<typename CharType>
uint32_t
@@ -104,4 +106,194 @@ Base64::Encode(const uint8_t* Source, uint32_t Length, CharType* Dest)
template uint32_t Base64::Encode<char>(const uint8_t* Source, uint32_t Length, char* Dest);
template uint32_t Base64::Encode<wchar_t>(const uint8_t* Source, uint32_t Length, wchar_t* Dest);
+template<typename CharType>
+bool
+Base64::Decode(const CharType* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength)
+{
+ // Length must be a multiple of 4
+ if (Length % 4 != 0)
+ {
+ OutLength = 0;
+ return false;
+ }
+
+ uint8_t* DecodedBytes = Dest;
+
+ // Process 4 encoded characters at a time, producing 3 decoded bytes
+ while (Length > 0)
+ {
+ // Count padding characters at the end
+ uint32_t PadCount = 0;
+ if (Source[3] == '=')
+ {
+ PadCount++;
+ if (Source[2] == '=')
+ {
+ PadCount++;
+ }
+ }
+
+ // Look up each character in the decoding table
+ uint8_t A = DecodingAlphabet[static_cast<uint8_t>(Source[0])];
+ uint8_t B = DecodingAlphabet[static_cast<uint8_t>(Source[1])];
+ uint8_t C = (PadCount >= 2) ? 0 : DecodingAlphabet[static_cast<uint8_t>(Source[2])];
+ uint8_t D = (PadCount >= 1) ? 0 : DecodingAlphabet[static_cast<uint8_t>(Source[3])];
+
+ // Check for invalid characters (0xFF means not in the base64 alphabet)
+ if (A == 0xFF || B == 0xFF || C == 0xFF || D == 0xFF)
+ {
+ OutLength = 0;
+ return false;
+ }
+
+ // Reconstruct the 24-bit value from 4 6-bit chunks
+ uint32_t ByteTriplet = (A << 18) | (B << 12) | (C << 6) | D;
+
+ // Extract the 3 bytes
+ *DecodedBytes++ = static_cast<uint8_t>(ByteTriplet >> 16);
+ if (PadCount < 2)
+ {
+ *DecodedBytes++ = static_cast<uint8_t>((ByteTriplet >> 8) & 0xFF);
+ }
+ if (PadCount < 1)
+ {
+ *DecodedBytes++ = static_cast<uint8_t>(ByteTriplet & 0xFF);
+ }
+
+ Source += 4;
+ Length -= 4;
+ }
+
+ OutLength = uint32_t(DecodedBytes - Dest);
+ return true;
+}
+
+template bool Base64::Decode<char>(const char* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength);
+template bool Base64::Decode<wchar_t>(const wchar_t* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength);
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Testing related code follows...
+//
+
+#if ZEN_WITH_TESTS
+
+using namespace std::string_literals;
+
+TEST_SUITE_BEGIN("core.base64");
+
+TEST_CASE("Base64")
+{
+ auto EncodeString = [](std::string_view Input) -> std::string {
+ std::string Result;
+ Result.resize(Base64::GetEncodedDataSize(uint32_t(Input.size())));
+ Base64::Encode(reinterpret_cast<const uint8_t*>(Input.data()), uint32_t(Input.size()), Result.data());
+ return Result;
+ };
+
+ auto DecodeString = [](std::string_view Input) -> std::string {
+ std::string Result;
+ Result.resize(Base64::GetMaxDecodedDataSize(uint32_t(Input.size())));
+ uint32_t DecodedLength = 0;
+ bool Success = Base64::Decode(Input.data(), uint32_t(Input.size()), reinterpret_cast<uint8_t*>(Result.data()), DecodedLength);
+ CHECK(Success);
+ Result.resize(DecodedLength);
+ return Result;
+ };
+
+ SUBCASE("Encode")
+ {
+ CHECK(EncodeString("") == ""s);
+ CHECK(EncodeString("f") == "Zg=="s);
+ CHECK(EncodeString("fo") == "Zm8="s);
+ CHECK(EncodeString("foo") == "Zm9v"s);
+ CHECK(EncodeString("foob") == "Zm9vYg=="s);
+ CHECK(EncodeString("fooba") == "Zm9vYmE="s);
+ CHECK(EncodeString("foobar") == "Zm9vYmFy"s);
+ }
+
+ SUBCASE("Decode")
+ {
+ CHECK(DecodeString("") == ""s);
+ CHECK(DecodeString("Zg==") == "f"s);
+ CHECK(DecodeString("Zm8=") == "fo"s);
+ CHECK(DecodeString("Zm9v") == "foo"s);
+ CHECK(DecodeString("Zm9vYg==") == "foob"s);
+ CHECK(DecodeString("Zm9vYmE=") == "fooba"s);
+ CHECK(DecodeString("Zm9vYmFy") == "foobar"s);
+ }
+
+ SUBCASE("RoundTrip")
+ {
+ auto RoundTrip = [&](const std::string& Input) {
+ std::string Encoded = EncodeString(Input);
+ std::string Decoded = DecodeString(Encoded);
+ CHECK(Decoded == Input);
+ };
+
+ RoundTrip("Hello, World!");
+ RoundTrip("Base64 encoding test with various lengths");
+ RoundTrip("A");
+ RoundTrip("AB");
+ RoundTrip("ABC");
+ RoundTrip("ABCD");
+ RoundTrip("\x00\x01\x02\xff\xfe\xfd"s);
+ }
+
+ SUBCASE("BinaryRoundTrip")
+ {
+ // Test with all byte values 0-255
+ uint8_t AllBytes[256];
+ for (int i = 0; i < 256; ++i)
+ {
+ AllBytes[i] = static_cast<uint8_t>(i);
+ }
+
+ char Encoded[Base64::GetEncodedDataSize(256) + 1];
+ Base64::Encode(AllBytes, 256, Encoded);
+
+ uint8_t Decoded[256];
+ uint32_t DecodedLength = 0;
+ bool Success = Base64::Decode(Encoded, uint32_t(strlen(Encoded)), Decoded, DecodedLength);
+ CHECK(Success);
+ CHECK(DecodedLength == 256);
+ CHECK(memcmp(AllBytes, Decoded, 256) == 0);
+ }
+
+ SUBCASE("DecodeInvalidInput")
+ {
+ uint8_t Dest[64];
+ uint32_t OutLength = 0;
+
+ // Length not a multiple of 4
+ CHECK_FALSE(Base64::Decode("abc", 3u, Dest, OutLength));
+
+ // Invalid character
+ CHECK_FALSE(Base64::Decode("ab!d", 4u, Dest, OutLength));
+ }
+
+ SUBCASE("EncodedDataSize")
+ {
+ CHECK(Base64::GetEncodedDataSize(0) == 0);
+ CHECK(Base64::GetEncodedDataSize(1) == 4);
+ CHECK(Base64::GetEncodedDataSize(2) == 4);
+ CHECK(Base64::GetEncodedDataSize(3) == 4);
+ CHECK(Base64::GetEncodedDataSize(4) == 8);
+ CHECK(Base64::GetEncodedDataSize(5) == 8);
+ CHECK(Base64::GetEncodedDataSize(6) == 8);
+ }
+
+ SUBCASE("MaxDecodedDataSize")
+ {
+ CHECK(Base64::GetMaxDecodedDataSize(0) == 0);
+ CHECK(Base64::GetMaxDecodedDataSize(4) == 3);
+ CHECK(Base64::GetMaxDecodedDataSize(8) == 6);
+ CHECK(Base64::GetMaxDecodedDataSize(12) == 9);
+ }
+}
+
+TEST_SUITE_END();
+
+#endif
+
} // namespace zen
diff --git a/src/zencore/basicfile.cpp b/src/zencore/basicfile.cpp
index bd4d119fb..9dcf7663a 100644
--- a/src/zencore/basicfile.cpp
+++ b/src/zencore/basicfile.cpp
@@ -888,6 +888,8 @@ WriteToTempFile(CompositeBuffer&& Buffer, const std::filesystem::path& Path)
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("core.basicfile");
+
TEST_CASE("BasicFile")
{
ScopedCurrentDirectoryChange _;
@@ -1081,6 +1083,8 @@ TEST_CASE("BasicFileBuffer")
}
}
+TEST_SUITE_END();
+
void
basicfile_forcelink()
{
diff --git a/src/zencore/blake3.cpp b/src/zencore/blake3.cpp
index 054f0d3a0..55f9b74af 100644
--- a/src/zencore/blake3.cpp
+++ b/src/zencore/blake3.cpp
@@ -123,7 +123,7 @@ BLAKE3::ToHexString(StringBuilderBase& outBuilder) const
char str[65];
ToHexString(str);
- outBuilder.AppendRange(str, &str[65]);
+ outBuilder.AppendRange(str, &str[StringLength]);
return outBuilder;
}
@@ -200,6 +200,8 @@ BLAKE3Stream::GetHash()
// return text;
// }
+TEST_SUITE_BEGIN("core.blake3");
+
TEST_CASE("BLAKE3")
{
SUBCASE("Basics")
@@ -237,6 +239,8 @@ TEST_CASE("BLAKE3")
}
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/callstack.cpp b/src/zencore/callstack.cpp
index 8aa1111bf..ee0b0625a 100644
--- a/src/zencore/callstack.cpp
+++ b/src/zencore/callstack.cpp
@@ -260,6 +260,8 @@ GetCallstackRaw(void* CaptureBuffer, int FramesToSkip, int FramesToCapture)
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("core.callstack");
+
TEST_CASE("Callstack.Basic")
{
void* Addresses[4];
@@ -272,6 +274,8 @@ TEST_CASE("Callstack.Basic")
}
}
+TEST_SUITE_END();
+
void
callstack_forcelink()
{
diff --git a/src/zencore/commandline.cpp b/src/zencore/commandline.cpp
index 426cf23d6..718ef9678 100644
--- a/src/zencore/commandline.cpp
+++ b/src/zencore/commandline.cpp
@@ -14,6 +14,7 @@ ZEN_THIRD_PARTY_INCLUDES_END
# include <crt_externs.h>
#endif
+#include <locale.h>
#include <functional>
namespace zen {
diff --git a/src/zencore/compactbinary.cpp b/src/zencore/compactbinary.cpp
index b43cc18f1..9c81305d0 100644
--- a/src/zencore/compactbinary.cpp
+++ b/src/zencore/compactbinary.cpp
@@ -1512,6 +1512,8 @@ uson_forcelink()
{
}
+TEST_SUITE_BEGIN("core.compactbinary");
+
TEST_CASE("guid")
{
using namespace std::literals;
@@ -1704,8 +1706,6 @@ TEST_CASE("uson.datetime")
//////////////////////////////////////////////////////////////////////////
-TEST_SUITE_BEGIN("core.datetime");
-
TEST_CASE("core.datetime.compare")
{
DateTime T1(2000, 12, 13);
@@ -1732,10 +1732,6 @@ TEST_CASE("core.datetime.add")
CHECK(dT + T1 - T2 == dT1);
}
-TEST_SUITE_END();
-
-TEST_SUITE_BEGIN("core.timespan");
-
TEST_CASE("core.timespan.compare")
{
TimeSpan T1(1000);
diff --git a/src/zencore/compactbinarybuilder.cpp b/src/zencore/compactbinarybuilder.cpp
index 63c0b9c5c..a9ba30750 100644
--- a/src/zencore/compactbinarybuilder.cpp
+++ b/src/zencore/compactbinarybuilder.cpp
@@ -710,6 +710,8 @@ usonbuilder_forcelink()
// return "";
// }
+TEST_SUITE_BEGIN("core.compactbinarybuilder");
+
TEST_CASE("usonbuilder.object")
{
using namespace std::literals;
@@ -1530,6 +1532,8 @@ TEST_CASE("usonbuilder.stream")
CHECK(ValidateCompactBinary(Object.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
}
}
+
+TEST_SUITE_END();
#endif
} // namespace zen
diff --git a/src/zencore/compactbinaryjson.cpp b/src/zencore/compactbinaryjson.cpp
index abbec360a..da560a449 100644
--- a/src/zencore/compactbinaryjson.cpp
+++ b/src/zencore/compactbinaryjson.cpp
@@ -654,6 +654,8 @@ cbjson_forcelink()
{
}
+TEST_SUITE_BEGIN("core.compactbinaryjson");
+
TEST_CASE("uson.json")
{
using namespace std::literals;
@@ -872,6 +874,8 @@ TEST_CASE("json.uson")
}
}
+TEST_SUITE_END();
+
#endif // ZEN_WITH_TESTS
} // namespace zen
diff --git a/src/zencore/compactbinarypackage.cpp b/src/zencore/compactbinarypackage.cpp
index ffe64f2e9..56a292ca6 100644
--- a/src/zencore/compactbinarypackage.cpp
+++ b/src/zencore/compactbinarypackage.cpp
@@ -805,6 +805,8 @@ usonpackage_forcelink()
{
}
+TEST_SUITE_BEGIN("core.compactbinarypackage");
+
TEST_CASE("usonpackage")
{
using namespace std::literals;
@@ -1343,6 +1345,8 @@ TEST_CASE("usonpackage.invalidpackage")
}
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/compactbinaryvalidation.cpp b/src/zencore/compactbinaryvalidation.cpp
index d7292f405..3e78f8ef1 100644
--- a/src/zencore/compactbinaryvalidation.cpp
+++ b/src/zencore/compactbinaryvalidation.cpp
@@ -753,10 +753,14 @@ usonvalidation_forcelink()
{
}
+TEST_SUITE_BEGIN("core.compactbinaryvalidation");
+
TEST_CASE("usonvalidation")
{
SUBCASE("Basic") {}
}
+
+TEST_SUITE_END();
#endif
} // namespace zen
diff --git a/src/zencore/compactbinaryyaml.cpp b/src/zencore/compactbinaryyaml.cpp
index 5122e952a..b7f2c55df 100644
--- a/src/zencore/compactbinaryyaml.cpp
+++ b/src/zencore/compactbinaryyaml.cpp
@@ -14,11 +14,6 @@
#include <string_view>
#include <vector>
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <ryml.hpp>
-#include <ryml_std.hpp>
-ZEN_THIRD_PARTY_INCLUDES_END
-
namespace zen {
//////////////////////////////////////////////////////////////////////////
@@ -26,193 +21,349 @@ namespace zen {
class CbYamlWriter
{
public:
- explicit CbYamlWriter(StringBuilderBase& InBuilder) : m_StrBuilder(InBuilder) { m_NodeStack.push_back(m_Tree.rootref()); }
+ explicit CbYamlWriter(StringBuilderBase& InBuilder) : m_Builder(InBuilder) {}
void WriteField(CbFieldView Field)
{
- ryml::NodeRef Node;
+ CbValue Accessor = Field.GetValue();
+ CbFieldType Type = Accessor.GetType();
- if (m_IsFirst)
+ switch (Type)
{
- Node = Top();
+ case CbFieldType::Object:
+ case CbFieldType::UniformObject:
+ WriteMapEntries(Field, 0);
+ break;
+ case CbFieldType::Array:
+ case CbFieldType::UniformArray:
+ WriteSeqEntries(Field, 0);
+ break;
+ default:
+ WriteScalarValue(Field);
+ m_Builder << '\n';
+ break;
+ }
+ }
+
+ void WriteMapEntry(CbFieldView Field, int32_t Indent)
+ {
+ WriteIndent(Indent);
+ WriteMapEntryContent(Field, Indent);
+ }
+
+ void WriteSeqEntry(CbFieldView Field, int32_t Indent)
+ {
+ CbValue Accessor = Field.GetValue();
+ CbFieldType Type = Accessor.GetType();
- m_IsFirst = false;
+ if (Type == CbFieldType::Object || Type == CbFieldType::UniformObject)
+ {
+ bool First = true;
+ for (CbFieldView MapChild : Field)
+ {
+ if (First)
+ {
+ WriteIndent(Indent);
+ m_Builder << "- ";
+ First = false;
+ }
+ else
+ {
+ WriteIndent(Indent + 1);
+ }
+ WriteMapEntryContent(MapChild, Indent + 1);
+ }
+ }
+ else if (Type == CbFieldType::Array || Type == CbFieldType::UniformArray)
+ {
+ WriteIndent(Indent);
+ m_Builder << "-\n";
+ WriteSeqEntries(Field, Indent + 1);
}
else
{
- Node = Top().append_child();
+ WriteIndent(Indent);
+ m_Builder << "- ";
+ WriteScalarValue(Field);
+ m_Builder << '\n';
}
+ }
- if (std::u8string_view Name = Field.GetU8Name(); !Name.empty())
+private:
+ void WriteMapEntries(CbFieldView MapField, int32_t Indent)
+ {
+ for (CbFieldView Child : MapField)
{
- Node.set_key_serialized(ryml::csubstr((const char*)Name.data(), Name.size()));
+ WriteIndent(Indent);
+ WriteMapEntryContent(Child, Indent);
}
+ }
+
+ void WriteMapEntryContent(CbFieldView Field, int32_t Indent)
+ {
+ std::u8string_view Name = Field.GetU8Name();
+ m_Builder << std::string_view(reinterpret_cast<const char*>(Name.data()), Name.size());
- switch (CbValue Accessor = Field.GetValue(); Accessor.GetType())
+ CbValue Accessor = Field.GetValue();
+ CbFieldType Type = Accessor.GetType();
+
+ if (IsContainer(Type))
{
- case CbFieldType::Null:
- Node.set_val("null");
- break;
- case CbFieldType::Object:
- case CbFieldType::UniformObject:
- Node |= ryml::MAP;
- m_NodeStack.push_back(Node);
- for (CbFieldView It : Field)
+ m_Builder << ":\n";
+ WriteFieldValue(Field, Indent + 1);
+ }
+ else
+ {
+ m_Builder << ": ";
+ WriteScalarValue(Field);
+ m_Builder << '\n';
+ }
+ }
+
+ void WriteSeqEntries(CbFieldView SeqField, int32_t Indent)
+ {
+ for (CbFieldView Child : SeqField)
+ {
+ CbValue Accessor = Child.GetValue();
+ CbFieldType Type = Accessor.GetType();
+
+ if (Type == CbFieldType::Object || Type == CbFieldType::UniformObject)
+ {
+ bool First = true;
+ for (CbFieldView MapChild : Child)
{
- WriteField(It);
+ if (First)
+ {
+ WriteIndent(Indent);
+ m_Builder << "- ";
+ First = false;
+ }
+ else
+ {
+ WriteIndent(Indent + 1);
+ }
+ WriteMapEntryContent(MapChild, Indent + 1);
}
- m_NodeStack.pop_back();
+ }
+ else if (Type == CbFieldType::Array || Type == CbFieldType::UniformArray)
+ {
+ WriteIndent(Indent);
+ m_Builder << "-\n";
+ WriteSeqEntries(Child, Indent + 1);
+ }
+ else
+ {
+ WriteIndent(Indent);
+ m_Builder << "- ";
+ WriteScalarValue(Child);
+ m_Builder << '\n';
+ }
+ }
+ }
+
+ void WriteFieldValue(CbFieldView Field, int32_t Indent)
+ {
+ CbValue Accessor = Field.GetValue();
+ CbFieldType Type = Accessor.GetType();
+
+ switch (Type)
+ {
+ case CbFieldType::Object:
+ case CbFieldType::UniformObject:
+ WriteMapEntries(Field, Indent);
break;
case CbFieldType::Array:
case CbFieldType::UniformArray:
- Node |= ryml::SEQ;
- m_NodeStack.push_back(Node);
- for (CbFieldView It : Field)
- {
- WriteField(It);
- }
- m_NodeStack.pop_back();
+ WriteSeqEntries(Field, Indent);
break;
- case CbFieldType::Binary:
- {
- ExtendableStringBuilder<256> Builder;
- const MemoryView Value = Accessor.AsBinary();
- ZEN_ASSERT(Value.GetSize() <= 512 * 1024 * 1024);
- const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize()));
- const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize));
- Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex);
-
- Node.set_key_serialized(Builder.c_str());
- }
+ case CbFieldType::CustomById:
+ WriteCustomById(Field.GetValue().AsCustomById(), Indent);
break;
- case CbFieldType::String:
- {
- const std::u8string_view U8String = Accessor.AsU8String();
- Node.set_val(ryml::csubstr((const char*)U8String.data(), U8String.size()));
- }
+ case CbFieldType::CustomByName:
+ WriteCustomByName(Field.GetValue().AsCustomByName(), Indent);
+ break;
+ default:
+ WriteScalarValue(Field);
+ m_Builder << '\n';
+ break;
+ }
+ }
+
+ void WriteScalarValue(CbFieldView Field)
+ {
+ CbValue Accessor = Field.GetValue();
+ switch (Accessor.GetType())
+ {
+ case CbFieldType::Null:
+ m_Builder << "null";
+ break;
+ case CbFieldType::BoolFalse:
+ m_Builder << "false";
+ break;
+ case CbFieldType::BoolTrue:
+ m_Builder << "true";
break;
case CbFieldType::IntegerPositive:
- Node << Accessor.AsIntegerPositive();
+ m_Builder << Accessor.AsIntegerPositive();
break;
case CbFieldType::IntegerNegative:
- Node << Accessor.AsIntegerNegative();
+ m_Builder << Accessor.AsIntegerNegative();
break;
case CbFieldType::Float32:
if (const float Value = Accessor.AsFloat32(); std::isfinite(Value))
- {
- Node << Value;
- }
+ m_Builder.Append(fmt::format("{}", Value));
else
- {
- Node << "null";
- }
+ m_Builder << "null";
break;
case CbFieldType::Float64:
if (const double Value = Accessor.AsFloat64(); std::isfinite(Value))
- {
- Node << Value;
- }
+ m_Builder.Append(fmt::format("{}", Value));
else
+ m_Builder << "null";
+ break;
+ case CbFieldType::String:
{
- Node << "null";
+ const std::u8string_view U8String = Accessor.AsU8String();
+ WriteString(std::string_view(reinterpret_cast<const char*>(U8String.data()), U8String.size()));
}
break;
- case CbFieldType::BoolFalse:
- Node << "false";
- break;
- case CbFieldType::BoolTrue:
- Node << "true";
+ case CbFieldType::Hash:
+ WriteString(Accessor.AsHash().ToHexString());
break;
case CbFieldType::ObjectAttachment:
case CbFieldType::BinaryAttachment:
- Node << Accessor.AsAttachment().ToHexString();
- break;
- case CbFieldType::Hash:
- Node << Accessor.AsHash().ToHexString();
+ WriteString(Accessor.AsAttachment().ToHexString());
break;
case CbFieldType::Uuid:
- Node << fmt::format("{}", Accessor.AsUuid());
+ WriteString(fmt::format("{}", Accessor.AsUuid()));
break;
case CbFieldType::DateTime:
- Node << DateTime(Accessor.AsDateTimeTicks()).ToIso8601();
+ WriteString(DateTime(Accessor.AsDateTimeTicks()).ToIso8601());
break;
case CbFieldType::TimeSpan:
if (const TimeSpan Span(Accessor.AsTimeSpanTicks()); Span.GetDays() == 0)
- {
- Node << Span.ToString("%h:%m:%s.%n");
- }
+ WriteString(Span.ToString("%h:%m:%s.%n"));
else
- {
- Node << Span.ToString("%d.%h:%m:%s.%n");
- }
+ WriteString(Span.ToString("%d.%h:%m:%s.%n"));
break;
case CbFieldType::ObjectId:
- Node << fmt::format("{}", Accessor.AsObjectId());
+ WriteString(fmt::format("{}", Accessor.AsObjectId()));
break;
- case CbFieldType::CustomById:
- {
- CbCustomById Custom = Accessor.AsCustomById();
+ case CbFieldType::Binary:
+ WriteBase64(Accessor.AsBinary());
+ break;
+ default:
+ ZEN_ASSERT_FORMAT(false, "invalid field type: {}", uint8_t(Accessor.GetType()));
+ break;
+ }
+ }
- Node |= ryml::MAP;
+ void WriteCustomById(CbCustomById Custom, int32_t Indent)
+ {
+ WriteIndent(Indent);
+ m_Builder << "Id: ";
+ m_Builder.Append(fmt::format("{}", Custom.Id));
+ m_Builder << '\n';
+
+ WriteIndent(Indent);
+ m_Builder << "Data: ";
+ WriteBase64(Custom.Data);
+ m_Builder << '\n';
+ }
- ryml::NodeRef IdNode = Node.append_child();
- IdNode.set_key("Id");
- IdNode.set_val_serialized(fmt::format("{}", Custom.Id));
+ void WriteCustomByName(CbCustomByName Custom, int32_t Indent)
+ {
+ WriteIndent(Indent);
+ m_Builder << "Name: ";
+ WriteString(std::string_view(reinterpret_cast<const char*>(Custom.Name.data()), Custom.Name.size()));
+ m_Builder << '\n';
+
+ WriteIndent(Indent);
+ m_Builder << "Data: ";
+ WriteBase64(Custom.Data);
+ m_Builder << '\n';
+ }
- ryml::NodeRef DataNode = Node.append_child();
- DataNode.set_key("Data");
+ void WriteBase64(MemoryView Value)
+ {
+ ZEN_ASSERT(Value.GetSize() <= 512 * 1024 * 1024);
+ ExtendableStringBuilder<256> Buf;
+ const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize()));
+ const size_t EncodedIndex = Buf.AddUninitialized(size_t(EncodedSize));
+ Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Buf.Data() + EncodedIndex);
+ WriteString(Buf.ToView());
+ }
- ExtendableStringBuilder<256> Builder;
- const MemoryView& Value = Custom.Data;
- const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize()));
- const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize));
- Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex);
+ void WriteString(std::string_view Str)
+ {
+ if (NeedsQuoting(Str))
+ {
+ m_Builder << '\'';
+ for (char C : Str)
+ {
+ if (C == '\'')
+ m_Builder << "''";
+ else
+ m_Builder << C;
+ }
+ m_Builder << '\'';
+ }
+ else
+ {
+ m_Builder << Str;
+ }
+ }
- DataNode.set_val_serialized(Builder.c_str());
- }
- break;
- case CbFieldType::CustomByName:
- {
- CbCustomByName Custom = Accessor.AsCustomByName();
+ void WriteIndent(int32_t Indent)
+ {
+ for (int32_t I = 0; I < Indent; ++I)
+ m_Builder << " ";
+ }
- Node |= ryml::MAP;
+ static bool NeedsQuoting(std::string_view Str)
+ {
+ if (Str.empty())
+ return false;
- ryml::NodeRef NameNode = Node.append_child();
- NameNode.set_key("Name");
- std::string_view Name = std::string_view((const char*)Custom.Name.data(), Custom.Name.size());
- NameNode.set_val_serialized(std::string(Name));
+ char First = Str[0];
+ if (First == ' ' || First == '\n' || First == '\t' || First == '\r' || First == '*' || First == '&' || First == '%' ||
+ First == '@' || First == '`')
+ return true;
- ryml::NodeRef DataNode = Node.append_child();
- DataNode.set_key("Data");
+ if (Str.size() >= 2 && Str[0] == '<' && Str[1] == '<')
+ return true;
- ExtendableStringBuilder<256> Builder;
- const MemoryView& Value = Custom.Data;
- const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize()));
- const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize));
- Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex);
+ char Last = Str.back();
+ if (Last == ' ' || Last == '\n' || Last == '\t' || Last == '\r')
+ return true;
- DataNode.set_val_serialized(Builder.c_str());
- }
- break;
- default:
- ZEN_ASSERT_FORMAT(false, "invalid field type: {}", uint8_t(Accessor.GetType()));
- break;
+ for (char C : Str)
+ {
+ if (C == '#' || C == ':' || C == '-' || C == '?' || C == ',' || C == '\n' || C == '{' || C == '}' || C == '[' || C == ']' ||
+ C == '\'' || C == '"')
+ return true;
}
- if (m_NodeStack.size() == 1)
+ return false;
+ }
+
+ static bool IsContainer(CbFieldType Type)
+ {
+ switch (Type)
{
- std::string Yaml = ryml::emitrs_yaml<std::string>(m_Tree);
- m_StrBuilder << Yaml;
+ case CbFieldType::Object:
+ case CbFieldType::UniformObject:
+ case CbFieldType::Array:
+ case CbFieldType::UniformArray:
+ case CbFieldType::CustomById:
+ case CbFieldType::CustomByName:
+ return true;
+ default:
+ return false;
}
}
-private:
- StringBuilderBase& m_StrBuilder;
- bool m_IsFirst = true;
-
- ryml::Tree m_Tree;
- std::vector<ryml::NodeRef> m_NodeStack;
- ryml::NodeRef& Top() { return m_NodeStack.back(); }
+ StringBuilderBase& m_Builder;
};
void
@@ -229,12 +380,40 @@ CompactBinaryToYaml(const CbArrayView& Array, StringBuilderBase& Builder)
Writer.WriteField(Array.AsFieldView());
}
+void
+CompactBinaryToYaml(MemoryView Data, StringBuilderBase& InBuilder)
+{
+ std::vector<CbFieldView> Fields = ReadCompactBinaryStream(Data);
+ if (Fields.empty())
+ return;
+
+ CbYamlWriter Writer(InBuilder);
+ if (Fields.size() == 1)
+ {
+ Writer.WriteField(Fields[0]);
+ return;
+ }
+
+ if (Fields[0].HasName())
+ {
+ for (const CbFieldView& Field : Fields)
+ Writer.WriteMapEntry(Field, 0);
+ }
+ else
+ {
+ for (const CbFieldView& Field : Fields)
+ Writer.WriteSeqEntry(Field, 0);
+ }
+}
+
#if ZEN_WITH_TESTS
void
cbyaml_forcelink()
{
}
+TEST_SUITE_BEGIN("core.compactbinaryyaml");
+
TEST_CASE("uson.yaml")
{
using namespace std::literals;
@@ -347,6 +526,8 @@ mixed_seq:
)"sv);
}
}
+
+TEST_SUITE_END();
#endif
} // namespace zen
diff --git a/src/zencore/compositebuffer.cpp b/src/zencore/compositebuffer.cpp
index 252ac9045..ed2b16384 100644
--- a/src/zencore/compositebuffer.cpp
+++ b/src/zencore/compositebuffer.cpp
@@ -297,6 +297,9 @@ CompositeBuffer::IterateRange(uint64_t Offset,
}
#if ZEN_WITH_TESTS
+
+TEST_SUITE_BEGIN("core.compositebuffer");
+
TEST_CASE("CompositeBuffer Null")
{
CompositeBuffer Buffer;
@@ -462,6 +465,8 @@ TEST_CASE("CompositeBuffer Composite")
TestIterateRange(8, 0, MakeMemoryView(FlatArray).Mid(8, 0), FlatView2);
}
+TEST_SUITE_END();
+
void
compositebuffer_forcelink()
{
diff --git a/src/zencore/compress.cpp b/src/zencore/compress.cpp
index 25ed0fc46..6aa0adce0 100644
--- a/src/zencore/compress.cpp
+++ b/src/zencore/compress.cpp
@@ -2420,6 +2420,8 @@ private:
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("core.compress");
+
TEST_CASE("CompressedBuffer")
{
uint8_t Zeroes[1024]{};
@@ -2967,6 +2969,8 @@ TEST_CASE("CompressedBufferReader")
}
}
+TEST_SUITE_END();
+
void
compress_forcelink()
{
diff --git a/src/zencore/crypto.cpp b/src/zencore/crypto.cpp
index 09eebb6ae..049854b42 100644
--- a/src/zencore/crypto.cpp
+++ b/src/zencore/crypto.cpp
@@ -449,6 +449,8 @@ crypto_forcelink()
{
}
+TEST_SUITE_BEGIN("core.crypto");
+
TEST_CASE("crypto.bits")
{
using CryptoBits256Bit = CryptoBits<256>;
@@ -500,6 +502,8 @@ TEST_CASE("crypto.aes")
}
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp
index 92a065707..8ed63565c 100644
--- a/src/zencore/filesystem.cpp
+++ b/src/zencore/filesystem.cpp
@@ -194,7 +194,7 @@ WipeDirectory(const wchar_t* DirPath, bool KeepDotFiles)
FindClose(hFind);
}
- return true;
+ return Success;
}
bool
@@ -1022,7 +1022,7 @@ TryCloneFile(const std::filesystem::path& FromPath, const std::filesystem::path&
return false;
}
fchmod(ToFd, 0666);
- ScopedFd $To = { FromFd };
+ ScopedFd $To = { ToFd };
ioctl(ToFd, FICLONE, FromFd);
@@ -1112,7 +1112,8 @@ CopyFile(const std::filesystem::path& FromPath, const std::filesystem::path& ToP
size_t FileSizeBytes = Stat.st_size;
- fchown(ToFd, Stat.st_uid, Stat.st_gid);
+ int $Ignore = fchown(ToFd, Stat.st_uid, Stat.st_gid);
+ ZEN_UNUSED($Ignore); // What's the appropriate error handling here?
// Copy impl
const size_t BufferSize = Min(FileSizeBytes, 64u << 10);
@@ -1326,11 +1327,6 @@ ReadFile(void* NativeHandle, void* Data, uint64_t Size, uint64_t FileOffset, uin
{
BytesRead = size_t(dwNumberOfBytesRead);
}
- else if ((BytesRead != NumberOfBytesToRead))
- {
- Ec = MakeErrorCode(ERROR_HANDLE_EOF);
- return;
- }
else
{
Ec = MakeErrorCodeFromLastError();
@@ -1344,20 +1340,15 @@ ReadFile(void* NativeHandle, void* Data, uint64_t Size, uint64_t FileOffset, uin
{
BytesRead = size_t(ReadResult);
}
- else if ((BytesRead != NumberOfBytesToRead))
- {
- Ec = MakeErrorCode(EIO);
- return;
- }
else
{
Ec = MakeErrorCodeFromLastError();
return;
}
#endif
- Size -= NumberOfBytesToRead;
- FileOffset += NumberOfBytesToRead;
- Data = reinterpret_cast<uint8_t*>(Data) + NumberOfBytesToRead;
+ Size -= BytesRead;
+ FileOffset += BytesRead;
+ Data = reinterpret_cast<uint8_t*>(Data) + BytesRead;
}
}
@@ -1408,7 +1399,7 @@ WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t Buffer
const uint64_t ChunkSize = Min<uint64_t>(WriteSize, uint64_t(2) * 1024 * 1024 * 1024);
#if ZEN_PLATFORM_WINDOWS
- hRes = Outfile.Write(DataPtr, gsl::narrow_cast<uint32_t>(WriteSize));
+ hRes = Outfile.Write(DataPtr, gsl::narrow_cast<uint32_t>(ChunkSize));
if (FAILED(hRes))
{
Outfile.Close();
@@ -1417,7 +1408,7 @@ WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t Buffer
ThrowSystemException(hRes, fmt::format("File write failed for '{}'", Path).c_str());
}
#else
- if (write(Fd, DataPtr, WriteSize) != int64_t(WriteSize))
+ if (write(Fd, DataPtr, ChunkSize) != int64_t(ChunkSize))
{
close(Fd);
std::error_code DummyEc;
@@ -3069,7 +3060,7 @@ SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly)
}
void
-MakeSafeAbsolutePathÍnPlace(std::filesystem::path& Path)
+MakeSafeAbsolutePathInPlace(std::filesystem::path& Path)
{
if (!Path.empty())
{
@@ -3091,7 +3082,7 @@ std::filesystem::path
MakeSafeAbsolutePath(const std::filesystem::path& Path)
{
std::filesystem::path Tmp(Path);
- MakeSafeAbsolutePathÍnPlace(Tmp);
+ MakeSafeAbsolutePathInPlace(Tmp);
return Tmp;
}
@@ -3319,6 +3310,8 @@ filesystem_forcelink()
{
}
+TEST_SUITE_BEGIN("core.filesystem");
+
TEST_CASE("filesystem")
{
using namespace std::filesystem;
@@ -3543,7 +3536,6 @@ TEST_CASE("PathBuilder")
Path.Reset();
Path.Append(fspath(L"/\u0119oo/"));
Path /= L"bar";
- printf("%ls\n", Path.ToPath().c_str());
CHECK(Path.ToView() == L"/\u0119oo/bar");
CHECK(Path.ToPath() == L"\\\u0119oo\\bar");
# endif
@@ -3614,6 +3606,8 @@ TEST_CASE("SharedMemory")
CHECK(!OpenSharedMemory("SharedMemoryTest0", 482, false));
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/include/zencore/base64.h b/src/zencore/include/zencore/base64.h
index 4d78b085f..08d9f3043 100644
--- a/src/zencore/include/zencore/base64.h
+++ b/src/zencore/include/zencore/base64.h
@@ -11,7 +11,11 @@ struct Base64
template<typename CharType>
static uint32_t Encode(const uint8_t* Source, uint32_t Length, CharType* Dest);
+ template<typename CharType>
+ static bool Decode(const CharType* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength);
+
static inline constexpr int32_t GetEncodedDataSize(uint32_t Size) { return ((Size + 2) / 3) * 4; }
+ static inline constexpr int32_t GetMaxDecodedDataSize(uint32_t Length) { return (Length / 4) * 3; }
};
} // namespace zen
diff --git a/src/zencore/include/zencore/blockingqueue.h b/src/zencore/include/zencore/blockingqueue.h
index e91fdc659..b6c93e937 100644
--- a/src/zencore/include/zencore/blockingqueue.h
+++ b/src/zencore/include/zencore/blockingqueue.h
@@ -2,6 +2,8 @@
#pragma once
+#include <zencore/zencore.h> // For ZEN_ASSERT
+
#include <atomic>
#include <condition_variable>
#include <deque>
diff --git a/src/zencore/include/zencore/compactbinaryfile.h b/src/zencore/include/zencore/compactbinaryfile.h
index 00c37e941..33f3e7bea 100644
--- a/src/zencore/include/zencore/compactbinaryfile.h
+++ b/src/zencore/include/zencore/compactbinaryfile.h
@@ -1,4 +1,5 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#pragma once
#include <zencore/compactbinary.h>
#include <zencore/iohash.h>
diff --git a/src/zencore/include/zencore/compactbinaryvalue.h b/src/zencore/include/zencore/compactbinaryvalue.h
index aa2d2821d..4ce8009b8 100644
--- a/src/zencore/include/zencore/compactbinaryvalue.h
+++ b/src/zencore/include/zencore/compactbinaryvalue.h
@@ -128,17 +128,21 @@ CbValue::AsString(CbFieldError* OutError, std::string_view Default) const
uint32_t ValueSizeByteCount;
const uint64_t ValueSize = ReadVarUInt(Chars, ValueSizeByteCount);
- if (OutError)
+ if (ValueSize >= (uint64_t(1) << 31))
{
- if (ValueSize >= (uint64_t(1) << 31))
+ if (OutError)
{
*OutError = CbFieldError::RangeError;
- return Default;
}
+ return Default;
+ }
+
+ if (OutError)
+ {
*OutError = CbFieldError::None;
}
- return std::string_view(Chars + ValueSizeByteCount, int32_t(ValueSize));
+ return std::string_view(Chars + ValueSizeByteCount, size_t(ValueSize));
}
inline std::u8string_view
@@ -148,17 +152,21 @@ CbValue::AsU8String(CbFieldError* OutError, std::u8string_view Default) const
uint32_t ValueSizeByteCount;
const uint64_t ValueSize = ReadVarUInt(Chars, ValueSizeByteCount);
- if (OutError)
+ if (ValueSize >= (uint64_t(1) << 31))
{
- if (ValueSize >= (uint64_t(1) << 31))
+ if (OutError)
{
*OutError = CbFieldError::RangeError;
- return Default;
}
+ return Default;
+ }
+
+ if (OutError)
+ {
*OutError = CbFieldError::None;
}
- return std::u8string_view(Chars + ValueSizeByteCount, int32_t(ValueSize));
+ return std::u8string_view(Chars + ValueSizeByteCount, size_t(ValueSize));
}
inline uint64_t
diff --git a/src/zencore/include/zencore/filesystem.h b/src/zencore/include/zencore/filesystem.h
index f28863679..16e2b59f8 100644
--- a/src/zencore/include/zencore/filesystem.h
+++ b/src/zencore/include/zencore/filesystem.h
@@ -64,80 +64,80 @@ std::filesystem::path PathFromHandle(void* NativeHandle, std::error_code& Ec);
*/
std::filesystem::path CanonicalPath(std::filesystem::path InPath, std::error_code& Ec);
-/** Query file size
+/** Check if a path exists and is a regular file (throws)
*/
bool IsFile(const std::filesystem::path& Path);
-/** Query file size
+/** Check if a path exists and is a regular file (does not throw)
*/
bool IsFile(const std::filesystem::path& Path, std::error_code& Ec);
-/** Query file size
+/** Check if a path exists and is a directory (throws)
*/
bool IsDir(const std::filesystem::path& Path);
-/** Query file size
+/** Check if a path exists and is a directory (does not throw)
*/
bool IsDir(const std::filesystem::path& Path, std::error_code& Ec);
-/** Query file size
+/** Delete file at path, if it exists (throws)
*/
bool RemoveFile(const std::filesystem::path& Path);
-/** Query file size
+/** Delete file at path, if it exists (does not throw)
*/
bool RemoveFile(const std::filesystem::path& Path, std::error_code& Ec);
-/** Query file size
+/** Delete directory at path, if it exists (throws)
*/
bool RemoveDir(const std::filesystem::path& Path);
-/** Query file size
+/** Delete directory at path, if it exists (does not throw)
*/
bool RemoveDir(const std::filesystem::path& Path, std::error_code& Ec);
-/** Query file size
+/** Query file size (throws)
*/
uint64_t FileSizeFromPath(const std::filesystem::path& Path);
-/** Query file size
+/** Query file size (does not throw)
*/
uint64_t FileSizeFromPath(const std::filesystem::path& Path, std::error_code& Ec);
-/** Query file size from native file handle
+/** Query file size from native file handle (throws)
*/
uint64_t FileSizeFromHandle(void* NativeHandle);
-/** Query file size from native file handle
+/** Query file size from native file handle (does not throw)
*/
uint64_t FileSizeFromHandle(void* NativeHandle, std::error_code& Ec);
/** Get a native time tick of last modification time
*/
-uint64_t GetModificationTickFromHandle(void* NativeHandle, std::error_code& Ec);
+uint64_t GetModificationTickFromPath(const std::filesystem::path& Filename);
/** Get a native time tick of last modification time
*/
-uint64_t GetModificationTickFromPath(const std::filesystem::path& Filename);
+uint64_t GetModificationTickFromHandle(void* NativeHandle, std::error_code& Ec);
bool TryGetFileProperties(const std::filesystem::path& Path,
uint64_t& OutSize,
uint64_t& OutModificationTick,
uint32_t& OutNativeModeOrAttributes);
-/** Move a file, if the files are not on the same drive the function will fail
+/** Move/rename a file, if the files are not on the same drive the function will fail (throws)
*/
void RenameFile(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath);
-/** Move a file, if the files are not on the same drive the function will fail
+/** Move/rename a file, if the files are not on the same drive the function will fail (does not throw)
*/
void RenameFile(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath, std::error_code& Ec);
-/** Move a directory, if the files are not on the same drive the function will fail
+/** Move/rename a directory, if the files are not on the same drive the function will fail (throws)
*/
void RenameDirectory(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath);
-/** Move a directory, if the files are not on the same drive the function will fail
+/** Move/rename a directory, if the files are not on the same drive the function will fail (does not throw)
*/
void RenameDirectory(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath, std::error_code& Ec);
@@ -421,7 +421,7 @@ uint32_t MakeFileModeReadOnly(uint32_t FileMode, bool ReadOnly);
bool SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly, std::error_code& Ec);
bool SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly);
-void MakeSafeAbsolutePathÍnPlace(std::filesystem::path& Path);
+void MakeSafeAbsolutePathInPlace(std::filesystem::path& Path);
[[nodiscard]] std::filesystem::path MakeSafeAbsolutePath(const std::filesystem::path& Path);
class SharedMemory
diff --git a/src/zencore/include/zencore/hashutils.h b/src/zencore/include/zencore/hashutils.h
index 4e877e219..8abfd4b6e 100644
--- a/src/zencore/include/zencore/hashutils.h
+++ b/src/zencore/include/zencore/hashutils.h
@@ -2,6 +2,10 @@
#pragma once
+#include <cstddef>
+#include <functional>
+#include <type_traits>
+
namespace zen {
template<typename T>
diff --git a/src/zencore/include/zencore/iobuffer.h b/src/zencore/include/zencore/iobuffer.h
index 182768ff6..82c201edd 100644
--- a/src/zencore/include/zencore/iobuffer.h
+++ b/src/zencore/include/zencore/iobuffer.h
@@ -426,22 +426,39 @@ private:
class IoBufferBuilder
{
public:
- static IoBuffer MakeFromFile(const std::filesystem::path& FileName, uint64_t Offset = 0, uint64_t Size = ~0ull);
- static IoBuffer MakeFromTemporaryFile(const std::filesystem::path& FileName);
- static IoBuffer MakeFromFileHandle(void* FileHandle, uint64_t Offset = 0, uint64_t Size = ~0ull);
- /** Make sure buffer data is memory resident, but avoid memory mapping data from files
- */
- static IoBuffer ReadFromFileMaybe(const IoBuffer& InBuffer);
- inline static IoBuffer MakeFromMemory(MemoryView Memory) { return IoBuffer(IoBuffer::Wrap, Memory.GetData(), Memory.GetSize()); }
- inline static IoBuffer MakeCloneFromMemory(const void* Ptr, size_t Sz)
+ static IoBuffer MakeFromFile(const std::filesystem::path& FileName,
+ uint64_t Offset = 0,
+ uint64_t Size = ~0ull,
+ ZenContentType ContentType = ZenContentType::kBinary);
+ static IoBuffer MakeFromTemporaryFile(const std::filesystem::path& FileName, ZenContentType ContentType = ZenContentType::kBinary);
+ static IoBuffer MakeFromFileHandle(void* FileHandle,
+ uint64_t Offset = 0,
+ uint64_t Size = ~0ull,
+ ZenContentType ContentType = ZenContentType::kBinary);
+ inline static IoBuffer MakeFromMemory(MemoryView Memory, ZenContentType ContentType = ZenContentType::kBinary)
+ {
+ IoBuffer NewBuffer(IoBuffer::Wrap, Memory.GetData(), Memory.GetSize());
+ NewBuffer.SetContentType(ContentType);
+ return NewBuffer;
+ }
+ inline static IoBuffer MakeCloneFromMemory(const void* Ptr, size_t Sz, ZenContentType ContentType = ZenContentType::kBinary)
{
if (Sz)
{
- return IoBuffer(IoBuffer::Clone, Ptr, Sz);
+ IoBuffer NewBuffer(IoBuffer::Clone, Ptr, Sz);
+ NewBuffer.SetContentType(ContentType);
+ return NewBuffer;
}
return {};
}
- inline static IoBuffer MakeCloneFromMemory(MemoryView Memory) { return MakeCloneFromMemory(Memory.GetData(), Memory.GetSize()); }
+ inline static IoBuffer MakeCloneFromMemory(MemoryView Memory, ZenContentType ContentType = ZenContentType::kBinary)
+ {
+ return MakeCloneFromMemory(Memory.GetData(), Memory.GetSize(), ContentType);
+ }
+
+ /** Make sure buffer data is memory resident, but avoid memory mapping data from files
+ */
+ static IoBuffer ReadFromFileMaybe(const IoBuffer& InBuffer);
};
void iobuffer_forcelink();
diff --git a/src/zencore/include/zencore/logbase.h b/src/zencore/include/zencore/logbase.h
index 00af68b0a..ece17a85e 100644
--- a/src/zencore/include/zencore/logbase.h
+++ b/src/zencore/include/zencore/logbase.h
@@ -4,96 +4,85 @@
#include <string_view>
-#define ZEN_LOG_LEVEL_TRACE 0
-#define ZEN_LOG_LEVEL_DEBUG 1
-#define ZEN_LOG_LEVEL_INFO 2
-#define ZEN_LOG_LEVEL_WARN 3
-#define ZEN_LOG_LEVEL_ERROR 4
-#define ZEN_LOG_LEVEL_CRITICAL 5
-#define ZEN_LOG_LEVEL_OFF 6
-
-#define ZEN_LEVEL_NAME_TRACE std::string_view("trace", 5)
-#define ZEN_LEVEL_NAME_DEBUG std::string_view("debug", 5)
-#define ZEN_LEVEL_NAME_INFO std::string_view("info", 4)
-#define ZEN_LEVEL_NAME_WARNING std::string_view("warning", 7)
-#define ZEN_LEVEL_NAME_ERROR std::string_view("error", 5)
-#define ZEN_LEVEL_NAME_CRITICAL std::string_view("critical", 8)
-#define ZEN_LEVEL_NAME_OFF std::string_view("off", 3)
-
-namespace zen::logging::level {
+namespace zen::logging {
enum LogLevel : int
{
- Trace = ZEN_LOG_LEVEL_TRACE,
- Debug = ZEN_LOG_LEVEL_DEBUG,
- Info = ZEN_LOG_LEVEL_INFO,
- Warn = ZEN_LOG_LEVEL_WARN,
- Err = ZEN_LOG_LEVEL_ERROR,
- Critical = ZEN_LOG_LEVEL_CRITICAL,
- Off = ZEN_LOG_LEVEL_OFF,
+ Trace,
+ Debug,
+ Info,
+ Warn,
+ Err,
+ Critical,
+ Off,
LogLevelCount
};
LogLevel ParseLogLevelString(std::string_view String);
std::string_view ToStringView(LogLevel Level);
-} // namespace zen::logging::level
-
-namespace zen::logging {
-
-void SetLogLevel(level::LogLevel NewLogLevel);
-level::LogLevel GetLogLevel();
+void SetLogLevel(LogLevel NewLogLevel);
+LogLevel GetLogLevel();
-} // namespace zen::logging
+struct SourceLocation
+{
+ constexpr SourceLocation() = default;
+ constexpr SourceLocation(const char* InFilename, int InLine) : Filename(InFilename), Line(InLine) {}
-namespace spdlog {
-class logger;
-}
+ constexpr operator bool() const noexcept { return Line != 0; }
-namespace zen::logging {
+ const char* Filename{nullptr};
+ int Line{0};
+};
-struct SourceLocation
+/** This encodes the constant parts of a log message which can be emitted once
+ * and then referred to by log events.
+ *
+ * It's *critical* that instances of this struct are permanent and never
+ * destroyed, as log messages will refer to them by pointer. The easiest way
+ * to ensure this is to create them as function-local statics.
+ *
+ * The logging macros already do this for you so this should not be something
+ * you normally would need to worry about.
+ */
+struct LogPoint
{
- constexpr SourceLocation() = default;
- constexpr SourceLocation(const char* filename_in, int line_in, const char* funcname_in)
- : filename(filename_in)
- , line(line_in)
- , funcname(funcname_in)
- {
- }
-
- constexpr bool empty() const noexcept { return line == 0; }
-
- // IMPORTANT NOTE: the layout of this class must match the spdlog::source_loc class
- // since we currently pass a pointer to it into spdlog after casting it to
- // spdlog::source_loc*
- //
- // This is intended to be an intermediate state, before we (probably) transition off
- // spdlog entirely
-
- const char* filename{nullptr};
- int line{0};
- const char* funcname{nullptr};
+ SourceLocation Location;
+ LogLevel Level;
+ std::string_view FormatString;
};
+class Logger;
+
} // namespace zen::logging
namespace zen {
+// Lightweight non-owning handle to a Logger. Loggers are owned by the Registry
+// via Ref<Logger>; LoggerRef exists as a cheap (raw pointer) handle that can be
+// stored in members and passed through logging macros without requiring the
+// complete Logger type or incurring refcount overhead on every log call.
struct LoggerRef
{
LoggerRef() = default;
- LoggerRef(spdlog::logger& InLogger) : SpdLogger(&InLogger) {}
+ LoggerRef(logging::Logger& InLogger) : m_Logger(&InLogger) {}
+ // This exists so that logging macros can pass LoggerRef or LogCategory
+ // to ZEN_LOG without needing to know which one it is
LoggerRef Logger() { return *this; }
- bool ShouldLog(int Level) const;
- inline operator bool() const { return SpdLogger != nullptr; }
+ bool ShouldLog(logging::LogLevel Level) const;
+ inline operator bool() const { return m_Logger != nullptr; }
+
+ inline logging::Logger* operator->() const { return m_Logger; }
+ inline logging::Logger& operator*() const { return *m_Logger; }
- void SetLogLevel(logging::level::LogLevel NewLogLevel);
- logging::level::LogLevel GetLogLevel();
+ void SetLogLevel(logging::LogLevel NewLogLevel);
+ logging::LogLevel GetLogLevel();
+ void Flush();
- spdlog::logger* SpdLogger = nullptr;
+private:
+ logging::Logger* m_Logger = nullptr;
};
} // namespace zen
diff --git a/src/zencore/include/zencore/logging.h b/src/zencore/include/zencore/logging.h
index afbbbd3ee..4b593c19e 100644
--- a/src/zencore/include/zencore/logging.h
+++ b/src/zencore/include/zencore/logging.h
@@ -9,16 +9,9 @@
#if ZEN_PLATFORM_WINDOWS
# define ZEN_LOG_SECTION(Id) ZEN_DATA_SECTION(Id)
-# pragma section(".zlog$f", read)
# pragma section(".zlog$l", read)
-# pragma section(".zlog$m", read)
-# pragma section(".zlog$s", read)
-# define ZEN_DECLARE_FUNCTION static constinit ZEN_LOG_SECTION(".zlog$f") char FuncName[] = __FUNCTION__;
-# define ZEN_LOG_FUNCNAME FuncName
#else
# define ZEN_LOG_SECTION(Id)
-# define ZEN_DECLARE_FUNCTION
-# define ZEN_LOG_FUNCNAME static_cast<const char*>(__func__)
#endif
namespace zen::logging {
@@ -31,39 +24,35 @@ void FlushLogging();
LoggerRef Default();
void SetDefault(std::string_view NewDefaultLoggerId);
LoggerRef ConsoleLog();
+void ResetConsoleLog();
void SuppressConsoleLog();
LoggerRef ErrorLog();
void SetErrorLog(std::string_view LoggerId);
LoggerRef Get(std::string_view Name);
-void ConfigureLogLevels(level::LogLevel Level, std::string_view Loggers);
+void ConfigureLogLevels(LogLevel Level, std::string_view Loggers);
void RefreshLogLevels();
-void RefreshLogLevels(level::LogLevel DefaultLevel);
-
+void RefreshLogLevels(LogLevel DefaultLevel);
+
+/** LogCategory allows for the creation of log categories that can be used with
+ * the logging macros just like a logger reference. The main purpose of this is
+ * to allow for static log categories in global scope where we can't actually
+ * go ahead and instantiate a logger immediately because the logging system may
+ * not be initialized yet.
+ */
struct LogCategory
{
- inline LogCategory(std::string_view InCategory) : CategoryName(InCategory) {}
-
- inline zen::LoggerRef Logger()
- {
- if (LoggerRef)
- {
- return LoggerRef;
- }
+ inline LogCategory(std::string_view InCategory) : m_CategoryName(InCategory) {}
- LoggerRef = zen::logging::Get(CategoryName);
- return LoggerRef;
- }
+ LoggerRef Logger();
- std::string CategoryName;
- zen::LoggerRef LoggerRef;
+private:
+ std::string m_CategoryName;
+ LoggerRef m_LoggerRef;
};
-void EmitConsoleLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args);
-void EmitLogMessage(LoggerRef& Logger, int LogLevel, std::string_view Message);
-void EmitLogMessage(LoggerRef& Logger, const SourceLocation& Location, int LogLevel, std::string_view Message);
-void EmitLogMessage(LoggerRef& Logger, int LogLevel, std::string_view Format, fmt::format_args Args);
-void EmitLogMessage(LoggerRef& Logger, const SourceLocation& Location, int LogLevel, std::string_view Format, fmt::format_args Args);
+void EmitConsoleLogMessage(const LogPoint& Lp, fmt::format_args Args);
+void EmitLogMessage(LoggerRef& Logger, const LogPoint& Lp, fmt::format_args Args);
template<typename... T>
auto
@@ -78,15 +67,14 @@ namespace zen {
extern LoggerRef TheDefaultLogger;
-inline LoggerRef
-Log()
-{
- if (TheDefaultLogger)
- {
- return TheDefaultLogger;
- }
- return zen::logging::ConsoleLog();
-}
+/**
+ * This is the default logger, which any ZEN_INFO et al will get if there's
+ * no Log() function declared in the current scope.
+ *
+ * Typically, classes which want to log to its own channel will declare a Log()
+ * member function which returns a LoggerRef created at construction time.
+ */
+LoggerRef Log();
using logging::ConsoleLog;
using logging::ErrorLog;
@@ -97,12 +85,6 @@ using zen::ConsoleLog;
using zen::ErrorLog;
using zen::Log;
-inline consteval bool
-LogIsErrorLevel(int LogLevel)
-{
- return (LogLevel == zen::logging::level::Err || LogLevel == zen::logging::level::Critical);
-};
-
#if ZEN_BUILD_DEBUG
# define ZEN_CHECK_FORMAT_STRING(fmtstr, ...) \
while (false) \
@@ -116,75 +98,66 @@ LogIsErrorLevel(int LogLevel)
}
#endif
-#define ZEN_LOG_WITH_LOCATION(InLogger, InLevel, fmtstr, ...) \
- do \
- { \
- using namespace std::literals; \
- ZEN_DECLARE_FUNCTION \
- static constinit ZEN_LOG_SECTION(".zlog$s") char FileName[] = __FILE__; \
- static constinit ZEN_LOG_SECTION(".zlog$m") char FormatString[] = fmtstr; \
- static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::SourceLocation Location{FileName, __LINE__, ZEN_LOG_FUNCNAME}; \
- zen::LoggerRef Logger = InLogger; \
- ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
- if (Logger.ShouldLog(InLevel)) \
- { \
- zen::logging::EmitLogMessage(Logger, \
- Location, \
- InLevel, \
- std::string_view(FormatString, sizeof FormatString - 1), \
- zen::logging::LogCaptureArguments(__VA_ARGS__)); \
- } \
+#define ZEN_LOG_WITH_LOCATION(InLogger, InLevel, fmtstr, ...) \
+ do \
+ { \
+ using namespace std::literals; \
+ static constinit ZEN_LOG_SECTION(".zlog$l") \
+ zen::logging::LogPoint LogPoint{zen::logging::SourceLocation{__FILE__, __LINE__}, InLevel, std::string_view(fmtstr)}; \
+ zen::LoggerRef Logger = InLogger; \
+ ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
+ if (Logger.ShouldLog(InLevel)) \
+ { \
+ zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
+ } \
} while (false);
-#define ZEN_LOG(InLogger, InLevel, fmtstr, ...) \
- do \
- { \
- using namespace std::literals; \
- static constinit ZEN_LOG_SECTION(".zlog$m") char FormatString[] = fmtstr; \
- zen::LoggerRef Logger = InLogger; \
- ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
- if (Logger.ShouldLog(InLevel)) \
- { \
- zen::logging::EmitLogMessage(Logger, \
- InLevel, \
- std::string_view(FormatString, sizeof FormatString - 1), \
- zen::logging::LogCaptureArguments(__VA_ARGS__)); \
- } \
+#define ZEN_LOG(InLogger, InLevel, fmtstr, ...) \
+ do \
+ { \
+ using namespace std::literals; \
+ static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \
+ zen::LoggerRef Logger = InLogger; \
+ ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
+ if (Logger.ShouldLog(InLevel)) \
+ { \
+ zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
+ } \
} while (false);
#define ZEN_DEFINE_LOG_CATEGORY_STATIC(Category, Name) \
static zen::logging::LogCategory Category { Name }
-#define ZEN_LOG_TRACE(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::level::Trace, fmtstr, ##__VA_ARGS__)
-#define ZEN_LOG_DEBUG(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::level::Debug, fmtstr, ##__VA_ARGS__)
-#define ZEN_LOG_INFO(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::level::Info, fmtstr, ##__VA_ARGS__)
-#define ZEN_LOG_WARN(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::level::Warn, fmtstr, ##__VA_ARGS__)
-#define ZEN_LOG_ERROR(Category, fmtstr, ...) ZEN_LOG_WITH_LOCATION(Category.Logger(), zen::logging::level::Err, fmtstr, ##__VA_ARGS__)
-#define ZEN_LOG_CRITICAL(Category, fmtstr, ...) \
- ZEN_LOG_WITH_LOCATION(Category.Logger(), zen::logging::level::Critical, fmtstr, ##__VA_ARGS__)
-
-#define ZEN_TRACE(fmtstr, ...) ZEN_LOG(Log(), zen::logging::level::Trace, fmtstr, ##__VA_ARGS__)
-#define ZEN_DEBUG(fmtstr, ...) ZEN_LOG(Log(), zen::logging::level::Debug, fmtstr, ##__VA_ARGS__)
-#define ZEN_INFO(fmtstr, ...) ZEN_LOG(Log(), zen::logging::level::Info, fmtstr, ##__VA_ARGS__)
-#define ZEN_WARN(fmtstr, ...) ZEN_LOG(Log(), zen::logging::level::Warn, fmtstr, ##__VA_ARGS__)
-#define ZEN_ERROR(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::level::Err, fmtstr, ##__VA_ARGS__)
-#define ZEN_CRITICAL(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::level::Critical, fmtstr, ##__VA_ARGS__)
-
-#define ZEN_CONSOLE_LOG(InLevel, fmtstr, ...) \
- do \
- { \
- using namespace std::literals; \
- ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
- zen::logging::EmitConsoleLogMessage(InLevel, fmtstr, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
+#define ZEN_LOG_TRACE(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::Trace, fmtstr, ##__VA_ARGS__)
+#define ZEN_LOG_DEBUG(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::Debug, fmtstr, ##__VA_ARGS__)
+#define ZEN_LOG_INFO(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::Info, fmtstr, ##__VA_ARGS__)
+#define ZEN_LOG_WARN(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::Warn, fmtstr, ##__VA_ARGS__)
+#define ZEN_LOG_ERROR(Category, fmtstr, ...) ZEN_LOG_WITH_LOCATION(Category.Logger(), zen::logging::Err, fmtstr, ##__VA_ARGS__)
+#define ZEN_LOG_CRITICAL(Category, fmtstr, ...) ZEN_LOG_WITH_LOCATION(Category.Logger(), zen::logging::Critical, fmtstr, ##__VA_ARGS__)
+
+#define ZEN_TRACE(fmtstr, ...) ZEN_LOG(Log(), zen::logging::Trace, fmtstr, ##__VA_ARGS__)
+#define ZEN_DEBUG(fmtstr, ...) ZEN_LOG(Log(), zen::logging::Debug, fmtstr, ##__VA_ARGS__)
+#define ZEN_INFO(fmtstr, ...) ZEN_LOG(Log(), zen::logging::Info, fmtstr, ##__VA_ARGS__)
+#define ZEN_WARN(fmtstr, ...) ZEN_LOG(Log(), zen::logging::Warn, fmtstr, ##__VA_ARGS__)
+#define ZEN_ERROR(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Err, fmtstr, ##__VA_ARGS__)
+#define ZEN_CRITICAL(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Critical, fmtstr, ##__VA_ARGS__)
+
+#define ZEN_CONSOLE_LOG(InLevel, fmtstr, ...) \
+ do \
+ { \
+ using namespace std::literals; \
+ static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \
+ ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
+ zen::logging::EmitConsoleLogMessage(LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
} while (false)
-#define ZEN_CONSOLE(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Info, fmtstr, ##__VA_ARGS__)
-#define ZEN_CONSOLE_TRACE(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Trace, fmtstr, ##__VA_ARGS__)
-#define ZEN_CONSOLE_DEBUG(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Debug, fmtstr, ##__VA_ARGS__)
-#define ZEN_CONSOLE_INFO(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Info, fmtstr, ##__VA_ARGS__)
-#define ZEN_CONSOLE_WARN(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Warn, fmtstr, ##__VA_ARGS__)
-#define ZEN_CONSOLE_ERROR(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Err, fmtstr, ##__VA_ARGS__)
-#define ZEN_CONSOLE_CRITICAL(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Critical, fmtstr, ##__VA_ARGS__)
+#define ZEN_CONSOLE(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__)
+#define ZEN_CONSOLE_TRACE(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Trace, fmtstr, ##__VA_ARGS__)
+#define ZEN_CONSOLE_DEBUG(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Debug, fmtstr, ##__VA_ARGS__)
+#define ZEN_CONSOLE_INFO(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__)
+#define ZEN_CONSOLE_WARN(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Warn, fmtstr, ##__VA_ARGS__)
+#define ZEN_CONSOLE_ERROR(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Err, fmtstr, ##__VA_ARGS__)
+#define ZEN_CONSOLE_CRITICAL(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Critical, fmtstr, ##__VA_ARGS__)
//////////////////////////////////////////////////////////////////////////
@@ -239,28 +212,28 @@ std::string_view EmitActivitiesForLogging(StringBuilderBase& OutString);
#define ZEN_LOG_SCOPE(...) ScopedLazyActivity $Activity##__LINE__([&](StringBuilderBase& Out) { Out << fmt::format(__VA_ARGS__); })
-#define ZEN_SCOPED_WARN(fmtstr, ...) \
- do \
- { \
- ExtendableStringBuilder<256> ScopeString; \
- const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \
- ZEN_LOG(Log(), zen::logging::level::Warn, fmtstr "{}", ##__VA_ARGS__, Scopes); \
+#define ZEN_SCOPED_WARN(fmtstr, ...) \
+ do \
+ { \
+ ExtendableStringBuilder<256> ScopeString; \
+ const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \
+ ZEN_LOG(Log(), zen::logging::Warn, fmtstr "{}", ##__VA_ARGS__, Scopes); \
} while (false)
-#define ZEN_SCOPED_ERROR(fmtstr, ...) \
- do \
- { \
- ExtendableStringBuilder<256> ScopeString; \
- const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \
- ZEN_LOG_WITH_LOCATION(Log(), zen::logging::level::Err, fmtstr "{}", ##__VA_ARGS__, Scopes); \
+#define ZEN_SCOPED_ERROR(fmtstr, ...) \
+ do \
+ { \
+ ExtendableStringBuilder<256> ScopeString; \
+ const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \
+ ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Err, fmtstr "{}", ##__VA_ARGS__, Scopes); \
} while (false)
-#define ZEN_SCOPED_CRITICAL(fmtstr, ...) \
- do \
- { \
- ExtendableStringBuilder<256> ScopeString; \
- const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \
- ZEN_LOG_WITH_LOCATION(Log(), zen::logging::level::Critical, fmtstr "{}", ##__VA_ARGS__, Scopes); \
+#define ZEN_SCOPED_CRITICAL(fmtstr, ...) \
+ do \
+ { \
+ ExtendableStringBuilder<256> ScopeString; \
+ const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \
+ ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Critical, fmtstr "{}", ##__VA_ARGS__, Scopes); \
} while (false)
ScopedActivityBase* GetThreadActivity();
diff --git a/src/zencore/include/zencore/logging/ansicolorsink.h b/src/zencore/include/zencore/logging/ansicolorsink.h
new file mode 100644
index 000000000..5060a8393
--- /dev/null
+++ b/src/zencore/include/zencore/logging/ansicolorsink.h
@@ -0,0 +1,33 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging/sink.h>
+
+#include <memory>
+
+namespace zen::logging {
+
+enum class ColorMode
+{
+ On,
+ Off,
+ Auto
+};
+
+class AnsiColorStdoutSink : public Sink
+{
+public:
+ explicit AnsiColorStdoutSink(ColorMode Mode = ColorMode::Auto);
+ ~AnsiColorStdoutSink() override;
+
+ void Log(const LogMessage& Msg) override;
+ void Flush() override;
+ void SetFormatter(std::unique_ptr<Formatter> InFormatter) override;
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+} // namespace zen::logging
diff --git a/src/zencore/include/zencore/logging/asyncsink.h b/src/zencore/include/zencore/logging/asyncsink.h
new file mode 100644
index 000000000..c49a1ccce
--- /dev/null
+++ b/src/zencore/include/zencore/logging/asyncsink.h
@@ -0,0 +1,30 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging/sink.h>
+
+#include <memory>
+#include <vector>
+
+namespace zen::logging {
+
+class AsyncSink : public Sink
+{
+public:
+ explicit AsyncSink(std::vector<SinkPtr> InSinks);
+ ~AsyncSink() override;
+
+ AsyncSink(const AsyncSink&) = delete;
+ AsyncSink& operator=(const AsyncSink&) = delete;
+
+ void Log(const LogMessage& Msg) override;
+ void Flush() override;
+ void SetFormatter(std::unique_ptr<Formatter> InFormatter) override;
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+} // namespace zen::logging
diff --git a/src/zencore/include/zencore/logging/formatter.h b/src/zencore/include/zencore/logging/formatter.h
new file mode 100644
index 000000000..11904d71d
--- /dev/null
+++ b/src/zencore/include/zencore/logging/formatter.h
@@ -0,0 +1,20 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging/logmsg.h>
+#include <zencore/logging/memorybuffer.h>
+
+#include <memory>
+
+namespace zen::logging {
+
+class Formatter
+{
+public:
+ virtual ~Formatter() = default;
+ virtual void Format(const LogMessage& Msg, MemoryBuffer& Dest) = 0;
+ virtual std::unique_ptr<Formatter> Clone() const = 0;
+};
+
+} // namespace zen::logging
diff --git a/src/zencore/include/zencore/logging/helpers.h b/src/zencore/include/zencore/logging/helpers.h
new file mode 100644
index 000000000..ce021e1a5
--- /dev/null
+++ b/src/zencore/include/zencore/logging/helpers.h
@@ -0,0 +1,122 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logbase.h>
+#include <zencore/logging/memorybuffer.h>
+
+#include <chrono>
+#include <ctime>
+#include <string_view>
+
+namespace zen::logging::helpers {
+
+inline void
+AppendStringView(std::string_view Sv, MemoryBuffer& Dest)
+{
+ Dest.append(Sv.data(), Sv.data() + Sv.size());
+}
+
+inline void
+AppendInt(int N, MemoryBuffer& Dest)
+{
+ fmt::format_int Formatted(N);
+ Dest.append(Formatted.data(), Formatted.data() + Formatted.size());
+}
+
+inline void
+Pad2(int N, MemoryBuffer& Dest)
+{
+ if (N >= 0 && N < 100)
+ {
+ Dest.push_back(static_cast<char>('0' + N / 10));
+ Dest.push_back(static_cast<char>('0' + N % 10));
+ }
+ else
+ {
+ fmt::format_int Formatted(N);
+ Dest.append(Formatted.data(), Formatted.data() + Formatted.size());
+ }
+}
+
+inline void
+Pad3(uint32_t N, MemoryBuffer& Dest)
+{
+ if (N < 1000)
+ {
+ Dest.push_back(static_cast<char>('0' + N / 100));
+ Dest.push_back(static_cast<char>('0' + (N / 10) % 10));
+ Dest.push_back(static_cast<char>('0' + N % 10));
+ }
+ else
+ {
+ AppendInt(static_cast<int>(N), Dest);
+ }
+}
+
+inline void
+PadUint(size_t N, unsigned int Width, MemoryBuffer& Dest)
+{
+ fmt::format_int Formatted(N);
+ auto StrLen = static_cast<unsigned int>(Formatted.size());
+ if (Width > StrLen)
+ {
+ for (unsigned int Pad = 0; Pad < Width - StrLen; ++Pad)
+ {
+ Dest.push_back('0');
+ }
+ }
+ Dest.append(Formatted.data(), Formatted.data() + Formatted.size());
+}
+
+template<typename ToDuration>
+inline ToDuration
+TimeFraction(std::chrono::system_clock::time_point Tp)
+{
+ using std::chrono::duration_cast;
+ using std::chrono::seconds;
+ auto Duration = Tp.time_since_epoch();
+ auto Secs = duration_cast<seconds>(Duration);
+ return duration_cast<ToDuration>(Duration) - duration_cast<ToDuration>(Secs);
+}
+
+inline std::tm
+SafeLocaltime(std::time_t Time)
+{
+ std::tm Result{};
+#if defined(_WIN32)
+ localtime_s(&Result, &Time);
+#else
+ localtime_r(&Time, &Result);
+#endif
+ return Result;
+}
+
+inline const char*
+ShortFilename(const char* Path)
+{
+ if (Path == nullptr)
+ {
+ return Path;
+ }
+
+ const char* It = Path;
+ const char* LastSep = Path;
+ while (*It)
+ {
+ if (*It == '/' || *It == '\\')
+ {
+ LastSep = It + 1;
+ }
+ ++It;
+ }
+ return LastSep;
+}
+
+inline std::string_view
+LevelToShortString(LogLevel Level)
+{
+ return ToStringView(Level);
+}
+
+} // namespace zen::logging::helpers
diff --git a/src/zencore/include/zencore/logging/logger.h b/src/zencore/include/zencore/logging/logger.h
new file mode 100644
index 000000000..39d1139a5
--- /dev/null
+++ b/src/zencore/include/zencore/logging/logger.h
@@ -0,0 +1,63 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging/sink.h>
+
+#include <atomic>
+#include <memory>
+#include <string_view>
+
+namespace zen::logging {
+
+class ErrorHandler
+{
+public:
+ virtual ~ErrorHandler() = default;
+ virtual void HandleError(const std::string_view& Msg) = 0;
+};
+
+class Logger : public RefCounted
+{
+public:
+ Logger(std::string_view InName, SinkPtr InSink);
+ Logger(std::string_view InName, std::span<const SinkPtr> InSinks);
+ ~Logger();
+
+ Logger(const Logger&) = delete;
+ Logger& operator=(const Logger&) = delete;
+
+ void Log(const LogPoint& Point, fmt::format_args Args);
+
+ bool ShouldLog(LogLevel InLevel) const { return InLevel >= m_Level.load(std::memory_order_relaxed); }
+
+ void SetLevel(LogLevel InLevel) { m_Level.store(InLevel, std::memory_order_relaxed); }
+ LogLevel GetLevel() const { return m_Level.load(std::memory_order_relaxed); }
+
+ void SetFlushLevel(LogLevel InLevel) { m_FlushLevel.store(InLevel, std::memory_order_relaxed); }
+ LogLevel GetFlushLevel() const { return m_FlushLevel.load(std::memory_order_relaxed); }
+
+ std::string_view Name() const;
+
+ void SetSinks(std::vector<SinkPtr> InSinks);
+ void AddSink(SinkPtr InSink);
+
+ void SetFormatter(std::unique_ptr<Formatter> InFormatter);
+
+ void SetErrorHandler(ErrorHandler* Handler);
+
+ void Flush();
+
+ Ref<Logger> Clone(std::string_view NewName) const;
+
+private:
+ void SinkIt(const LogMessage& Msg);
+ void FlushIfNeeded(LogLevel InLevel);
+
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+ std::atomic<LogLevel> m_Level{Info};
+ std::atomic<LogLevel> m_FlushLevel{Off};
+};
+
+} // namespace zen::logging
diff --git a/src/zencore/include/zencore/logging/logmsg.h b/src/zencore/include/zencore/logging/logmsg.h
new file mode 100644
index 000000000..1d8b6b1b7
--- /dev/null
+++ b/src/zencore/include/zencore/logging/logmsg.h
@@ -0,0 +1,66 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logbase.h>
+
+#include <chrono>
+#include <string_view>
+
+namespace zen::logging {
+
+using LogClock = std::chrono::system_clock;
+
+struct LogMessage
+{
+ LogMessage() = default;
+
+ LogMessage(const LogPoint& InPoint, std::string_view InLoggerName, std::string_view InPayload)
+ : m_LoggerName(InLoggerName)
+ , m_Level(InPoint.Level)
+ , m_Time(LogClock::now())
+ , m_Source(InPoint.Location)
+ , m_Payload(InPayload)
+ , m_Point(&InPoint)
+ {
+ }
+
+ std::string_view GetPayload() const { return m_Payload; }
+ int GetThreadId() const { return m_ThreadId; }
+ LogClock::time_point GetTime() const { return m_Time; }
+ LogLevel GetLevel() const { return m_Level; }
+ std::string_view GetLoggerName() const { return m_LoggerName; }
+ const SourceLocation& GetSource() const { return m_Source; }
+ const LogPoint& GetLogPoint() const { return *m_Point; }
+
+ void SetThreadId(int InThreadId) { m_ThreadId = InThreadId; }
+ void SetPayload(std::string_view InPayload) { m_Payload = InPayload; }
+ void SetLoggerName(std::string_view InName) { m_LoggerName = InName; }
+ void SetLevel(LogLevel InLevel) { m_Level = InLevel; }
+ void SetTime(LogClock::time_point InTime) { m_Time = InTime; }
+ void SetSource(const SourceLocation& InSource) { m_Source = InSource; }
+
+ mutable size_t ColorRangeStart = 0;
+ mutable size_t ColorRangeEnd = 0;
+
+private:
+ static constexpr LogPoint s_DefaultPoints[LogLevelCount] = {
+ {{}, Trace, {}},
+ {{}, Debug, {}},
+ {{}, Info, {}},
+ {{}, Warn, {}},
+ {{}, Err, {}},
+ {{}, Critical, {}},
+ {{}, Off, {}},
+ };
+
+ std::string_view m_LoggerName;
+ LogLevel m_Level = Off;
+ std::chrono::system_clock::time_point m_Time;
+ SourceLocation m_Source;
+ std::string_view m_Payload;
+ const LogPoint* m_Point = &s_DefaultPoints[Off];
+ int m_ThreadId = 0;
+};
+
+} // namespace zen::logging
diff --git a/src/zencore/include/zencore/logging/memorybuffer.h b/src/zencore/include/zencore/logging/memorybuffer.h
new file mode 100644
index 000000000..cd0ff324f
--- /dev/null
+++ b/src/zencore/include/zencore/logging/memorybuffer.h
@@ -0,0 +1,11 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <fmt/format.h>
+
+namespace zen::logging {
+
+using MemoryBuffer = fmt::basic_memory_buffer<char, 250>;
+
+} // namespace zen::logging
diff --git a/src/zencore/include/zencore/logging/messageonlyformatter.h b/src/zencore/include/zencore/logging/messageonlyformatter.h
new file mode 100644
index 000000000..ce25fe9a6
--- /dev/null
+++ b/src/zencore/include/zencore/logging/messageonlyformatter.h
@@ -0,0 +1,22 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging/formatter.h>
+#include <zencore/logging/helpers.h>
+
+namespace zen::logging {
+
+class MessageOnlyFormatter : public Formatter
+{
+public:
+ void Format(const LogMessage& Msg, MemoryBuffer& Dest) override
+ {
+ helpers::AppendStringView(Msg.GetPayload(), Dest);
+ Dest.push_back('\n');
+ }
+
+ std::unique_ptr<Formatter> Clone() const override { return std::make_unique<MessageOnlyFormatter>(); }
+};
+
+} // namespace zen::logging
diff --git a/src/zencore/include/zencore/logging/msvcsink.h b/src/zencore/include/zencore/logging/msvcsink.h
new file mode 100644
index 000000000..48ea1b915
--- /dev/null
+++ b/src/zencore/include/zencore/logging/msvcsink.h
@@ -0,0 +1,30 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging/sink.h>
+
+#if ZEN_PLATFORM_WINDOWS
+
+# include <mutex>
+
+namespace zen::logging {
+
+class MsvcSink : public Sink
+{
+public:
+ MsvcSink();
+ ~MsvcSink() override = default;
+
+ void Log(const LogMessage& Msg) override;
+ void Flush() override;
+ void SetFormatter(std::unique_ptr<Formatter> InFormatter) override;
+
+private:
+ std::mutex m_Mutex;
+ std::unique_ptr<Formatter> m_Formatter;
+};
+
+} // namespace zen::logging
+
+#endif // ZEN_PLATFORM_WINDOWS
diff --git a/src/zencore/include/zencore/logging/nullsink.h b/src/zencore/include/zencore/logging/nullsink.h
new file mode 100644
index 000000000..7ac5677c6
--- /dev/null
+++ b/src/zencore/include/zencore/logging/nullsink.h
@@ -0,0 +1,17 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging/sink.h>
+
+namespace zen::logging {
+
+class NullSink : public Sink
+{
+public:
+ void Log(const LogMessage& /*Msg*/) override {}
+ void Flush() override {}
+ void SetFormatter(std::unique_ptr<Formatter> /*InFormatter*/) override {}
+};
+
+} // namespace zen::logging
diff --git a/src/zencore/include/zencore/logging/registry.h b/src/zencore/include/zencore/logging/registry.h
new file mode 100644
index 000000000..a4d3692d2
--- /dev/null
+++ b/src/zencore/include/zencore/logging/registry.h
@@ -0,0 +1,70 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging/logger.h>
+
+#include <chrono>
+#include <memory>
+#include <span>
+#include <string>
+#include <type_traits>
+#include <utility>
+
+namespace zen::logging {
+
+class Registry
+{
+public:
+ using LogLevels = std::span<const std::pair<std::string, LogLevel>>;
+
+ static Registry& Instance();
+ void Shutdown();
+
+ void Register(Ref<Logger> InLogger);
+ void Drop(const std::string& Name);
+ Ref<Logger> Get(const std::string& Name);
+
+ void SetDefaultLogger(Ref<Logger> InLogger);
+ Logger* DefaultLoggerRaw();
+ Ref<Logger> DefaultLogger();
+
+ void SetGlobalLevel(LogLevel Level);
+ LogLevel GetGlobalLevel() const;
+ void SetLevels(LogLevels Levels, LogLevel* DefaultLevel);
+
+ void FlushAll();
+ void FlushOn(LogLevel Level);
+ void FlushEvery(std::chrono::seconds Interval);
+
+ // Change formatter on all registered loggers
+ void SetFormatter(std::unique_ptr<Formatter> InFormatter);
+
+ // Apply function to all registered loggers. Note that the function will
+ // be called while the registry mutex is held, so it should be fast and
+ // not attempt to call back into the registry.
+ template<typename Func>
+ void ApplyAll(Func&& F)
+ {
+ ApplyAllImpl([](void* Ctx, Ref<Logger> L) { (*static_cast<std::decay_t<Func>*>(Ctx))(std::move(L)); }, &F);
+ }
+
+ // Set error handler for all loggers in the registry. The handler is called
+ // if any logger encounters an error during logging or flushing.
+ // The caller must ensure the handler outlives the registry.
+ void SetErrorHandler(ErrorHandler* Handler);
+
+private:
+ void ApplyAllImpl(void (*Func)(void*, Ref<Logger>), void* Context);
+
+ Registry();
+ ~Registry();
+
+ Registry(const Registry&) = delete;
+ Registry& operator=(const Registry&) = delete;
+
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+} // namespace zen::logging
diff --git a/src/zencore/include/zencore/logging/sink.h b/src/zencore/include/zencore/logging/sink.h
new file mode 100644
index 000000000..172176a4e
--- /dev/null
+++ b/src/zencore/include/zencore/logging/sink.h
@@ -0,0 +1,34 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenbase/refcount.h>
+#include <zencore/logging/formatter.h>
+#include <zencore/logging/logmsg.h>
+
+#include <atomic>
+#include <memory>
+
+namespace zen::logging {
+
+class Sink : public RefCounted
+{
+public:
+ virtual ~Sink() = default;
+
+ virtual void Log(const LogMessage& Msg) = 0;
+ virtual void Flush() = 0;
+
+ virtual void SetFormatter(std::unique_ptr<Formatter> InFormatter) = 0;
+
+ bool ShouldLog(LogLevel InLevel) const { return InLevel >= m_Level.load(std::memory_order_relaxed); }
+ void SetLevel(LogLevel InLevel) { m_Level.store(InLevel, std::memory_order_relaxed); }
+ LogLevel GetLevel() const { return m_Level.load(std::memory_order_relaxed); }
+
+protected:
+ std::atomic<LogLevel> m_Level{Trace};
+};
+
+using SinkPtr = Ref<Sink>;
+
+} // namespace zen::logging
diff --git a/src/zencore/include/zencore/logging/tracesink.h b/src/zencore/include/zencore/logging/tracesink.h
new file mode 100644
index 000000000..785c51e10
--- /dev/null
+++ b/src/zencore/include/zencore/logging/tracesink.h
@@ -0,0 +1,27 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging/sink.h>
+
+namespace zen::logging {
+
+#if ZEN_WITH_TRACE
+
+/**
+ * A logging sink that forwards log messages to the trace system.
+ *
+ * Work-in-progress, not fully implemented.
+ */
+
+class TraceSink : public Sink
+{
+public:
+ void Log(const LogMessage& Msg) override;
+ void Flush() override;
+ void SetFormatter(std::unique_ptr<Formatter> InFormatter) override;
+};
+
+#endif
+
+} // namespace zen::logging
diff --git a/src/zencore/include/zencore/md5.h b/src/zencore/include/zencore/md5.h
index d934dd86b..3b0b7cae6 100644
--- a/src/zencore/include/zencore/md5.h
+++ b/src/zencore/include/zencore/md5.h
@@ -43,6 +43,8 @@ public:
MD5 GetHash();
private:
+ // Opaque storage for MD5_CTX (104 bytes, aligned to uint32_t)
+ alignas(4) uint8_t m_Context[104];
};
void md5_forcelink(); // internal
diff --git a/src/zencore/include/zencore/meta.h b/src/zencore/include/zencore/meta.h
index 82eb5cc30..20ec4ac6f 100644
--- a/src/zencore/include/zencore/meta.h
+++ b/src/zencore/include/zencore/meta.h
@@ -1,4 +1,5 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#pragma once
/* This file contains utility functions for meta programming
*
diff --git a/src/zencore/include/zencore/mpscqueue.h b/src/zencore/include/zencore/mpscqueue.h
index 19e410d85..d97c433fd 100644
--- a/src/zencore/include/zencore/mpscqueue.h
+++ b/src/zencore/include/zencore/mpscqueue.h
@@ -22,10 +22,10 @@ namespace zen {
template<typename ElementType>
struct TypeCompatibleStorage
{
- ElementType* Data() { return (ElementType*)this; }
- const ElementType* Data() const { return (const ElementType*)this; }
+ ElementType* Data() { return reinterpret_cast<ElementType*>(&Storage); }
+ const ElementType* Data() const { return reinterpret_cast<const ElementType*>(&Storage); }
- alignas(ElementType) char DataMember;
+ alignas(ElementType) char Storage[sizeof(ElementType)];
};
/** Fast multi-producer/single-consumer unbounded concurrent queue.
@@ -58,7 +58,7 @@ public:
Tail = Next;
Next = Tail->Next.load(std::memory_order_relaxed);
- std::destroy_at((ElementType*)&Tail->Value);
+ std::destroy_at(Tail->Value.Data());
delete Tail;
}
}
@@ -67,7 +67,7 @@ public:
void Enqueue(ArgTypes&&... Args)
{
Node* New = new Node;
- new (&New->Value) ElementType(std::forward<ArgTypes>(Args)...);
+ new (New->Value.Data()) ElementType(std::forward<ArgTypes>(Args)...);
Node* Prev = Head.exchange(New, std::memory_order_acq_rel);
Prev->Next.store(New, std::memory_order_release);
@@ -82,7 +82,7 @@ public:
return {};
}
- ElementType* ValuePtr = (ElementType*)&Next->Value;
+ ElementType* ValuePtr = Next->Value.Data();
std::optional<ElementType> Res{std::move(*ValuePtr)};
std::destroy_at(ValuePtr);
@@ -100,9 +100,11 @@ private:
};
private:
- std::atomic<Node*> Head; // accessed only by producers
- alignas(hardware_constructive_interference_size)
- Node* Tail; // accessed only by consumer, hence should be on a different cache line than `Head`
+ // Use a fixed constant to avoid GCC's -Winterference-size warning with std::hardware_destructive_interference_size
+ static constexpr std::size_t CacheLineSize = 64;
+
+ alignas(CacheLineSize) std::atomic<Node*> Head; // accessed only by producers
+ alignas(CacheLineSize) Node* Tail; // accessed only by consumer, separate cache line from Head
};
void mpscqueue_forcelink();
diff --git a/src/zencore/include/zencore/process.h b/src/zencore/include/zencore/process.h
index e3b7a70d7..809312c7b 100644
--- a/src/zencore/include/zencore/process.h
+++ b/src/zencore/include/zencore/process.h
@@ -9,6 +9,10 @@
namespace zen {
+#if ZEN_PLATFORM_WINDOWS
+class JobObject;
+#endif
+
/** Basic process abstraction
*/
class ProcessHandle
@@ -46,6 +50,7 @@ private:
/** Basic process creation
*/
+
struct CreateProcOptions
{
enum
@@ -63,6 +68,9 @@ struct CreateProcOptions
const std::filesystem::path* WorkingDirectory = nullptr;
uint32_t Flags = 0;
std::filesystem::path StdoutFile;
+#if ZEN_PLATFORM_WINDOWS
+ JobObject* AssignToJob = nullptr; // When set, the process is created suspended, assigned to the job, then resumed
+#endif
};
#if ZEN_PLATFORM_WINDOWS
@@ -99,12 +107,38 @@ private:
std::vector<HandleType> m_ProcessHandles;
};
+#if ZEN_PLATFORM_WINDOWS
+/** Windows Job Object wrapper
+ *
+ * When configured with JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, the OS will
+ * terminate all assigned child processes when the job handle is closed
+ * (including abnormal termination of the owning process). This provides
+ * an OS-level guarantee against orphaned child processes.
+ */
+class JobObject
+{
+public:
+ JobObject();
+ ~JobObject();
+ JobObject(const JobObject&) = delete;
+ JobObject& operator=(const JobObject&) = delete;
+
+ void Initialize();
+ bool AssignProcess(void* ProcessHandle);
+ [[nodiscard]] bool IsValid() const;
+
+private:
+ void* m_JobHandle = nullptr;
+};
+#endif // ZEN_PLATFORM_WINDOWS
+
bool IsProcessRunning(int pid);
bool IsProcessRunning(int pid, std::error_code& OutEc);
int GetCurrentProcessId();
int GetProcessId(CreateProcResult ProcId);
std::filesystem::path GetProcessExecutablePath(int Pid, std::error_code& OutEc);
+std::string GetProcessCommandLine(int Pid, std::error_code& OutEc);
std::error_code FindProcess(const std::filesystem::path& ExecutableImage, ProcessHandle& OutHandle, bool IncludeSelf = true);
/** Wait for all threads in the current process to exit (except the calling thread)
diff --git a/src/zencore/include/zencore/sentryintegration.h b/src/zencore/include/zencore/sentryintegration.h
index faf1238b7..27e5a8a82 100644
--- a/src/zencore/include/zencore/sentryintegration.h
+++ b/src/zencore/include/zencore/sentryintegration.h
@@ -11,11 +11,9 @@
#if ZEN_USE_SENTRY
-# include <memory>
+# include <zencore/logging/logger.h>
-ZEN_THIRD_PARTY_INCLUDES_START
-# include <spdlog/logger.h>
-ZEN_THIRD_PARTY_INCLUDES_END
+# include <memory>
namespace sentry {
@@ -42,6 +40,7 @@ public:
};
void Initialize(const Config& Conf, const std::string& CommandLine);
+ void Close();
void LogStartupInformation();
static void ClearCaches();
@@ -53,7 +52,7 @@ private:
std::string m_SentryUserName;
std::string m_SentryHostName;
std::string m_SentryId;
- std::shared_ptr<spdlog::logger> m_SentryLogger;
+ Ref<logging::Logger> m_SentryLogger;
};
} // namespace zen
diff --git a/src/zencore/include/zencore/sharedbuffer.h b/src/zencore/include/zencore/sharedbuffer.h
index c57e9f568..3d4c19282 100644
--- a/src/zencore/include/zencore/sharedbuffer.h
+++ b/src/zencore/include/zencore/sharedbuffer.h
@@ -116,14 +116,15 @@ public:
inline void Reset() { m_Buffer = nullptr; }
inline bool GetFileReference(IoBufferFileReference& OutRef) const
{
- if (const IoBufferExtendedCore* Core = m_Buffer->ExtendedCore())
+ if (!IsNull())
{
- return Core->GetFileReference(OutRef);
- }
- else
- {
- return false;
+ if (const IoBufferExtendedCore* Core = m_Buffer->ExtendedCore())
+ {
+ return Core->GetFileReference(OutRef);
+ }
}
+
+ return false;
}
[[nodiscard]] MemoryView GetView() const
diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h
index cbff6454f..4deca63ed 100644
--- a/src/zencore/include/zencore/string.h
+++ b/src/zencore/include/zencore/string.h
@@ -8,7 +8,6 @@
#include <stdint.h>
#include <string.h>
#include <charconv>
-#include <codecvt>
#include <compare>
#include <concepts>
#include <optional>
@@ -51,7 +50,7 @@ StringLength(const wchar_t* str)
return wcslen(str);
}
-inline bool
+inline int
StringCompare(const char16_t* s1, const char16_t* s2)
{
char16_t c1, c2;
@@ -66,7 +65,7 @@ StringCompare(const char16_t* s1, const char16_t* s2)
++s1;
++s2;
}
- return uint16_t(c1) - uint16_t(c2);
+ return int(uint16_t(c1)) - int(uint16_t(c2));
}
inline bool
@@ -122,10 +121,10 @@ public:
StringBuilderImpl() = default;
~StringBuilderImpl();
- StringBuilderImpl(const StringBuilderImpl&) = delete;
- StringBuilderImpl(const StringBuilderImpl&&) = delete;
+ StringBuilderImpl(const StringBuilderImpl&) = delete;
+ StringBuilderImpl(StringBuilderImpl&&) = delete;
const StringBuilderImpl& operator=(const StringBuilderImpl&) = delete;
- const StringBuilderImpl& operator=(const StringBuilderImpl&&) = delete;
+ StringBuilderImpl& operator=(StringBuilderImpl&&) = delete;
inline size_t AddUninitialized(size_t Count)
{
@@ -374,9 +373,9 @@ protected:
[[noreturn]] void Fail(const char* FailReason); // note: throws exception
- C* m_Base;
- C* m_CurPos;
- C* m_End;
+ C* m_Base = nullptr;
+ C* m_CurPos = nullptr;
+ C* m_End = nullptr;
bool m_IsDynamic = false;
bool m_IsExtendable = false;
};
@@ -773,8 +772,9 @@ std::optional<T>
ParseInt(const std::string_view& Input)
{
T Out = 0;
- const std::from_chars_result Result = std::from_chars(Input.data(), Input.data() + Input.size(), Out);
- if (Result.ec == std::errc::invalid_argument || Result.ec == std::errc::result_out_of_range)
+ const char* End = Input.data() + Input.size();
+ const std::from_chars_result Result = std::from_chars(Input.data(), End, Out);
+ if (Result.ec == std::errc::invalid_argument || Result.ec == std::errc::result_out_of_range || Result.ptr != End)
{
return std::nullopt;
}
@@ -797,6 +797,22 @@ HashStringDjb2(const std::string_view& InString)
}
constexpr uint32_t
+HashStringDjb2(const std::span<const std::string_view> InStrings)
+{
+ uint32_t HashValue = 5381;
+
+ for (const std::string_view& String : InStrings)
+ {
+ for (int CurChar : String)
+ {
+ HashValue = HashValue * 33 + CurChar;
+ }
+ }
+
+ return HashValue;
+}
+
+constexpr uint32_t
HashStringAsLowerDjb2(const std::string_view& InString)
{
uint32_t HashValue = 5381;
@@ -1249,6 +1265,8 @@ private:
uint64_t LoMask, HiMask;
};
+std::string HideSensitiveString(std::string_view String);
+
//////////////////////////////////////////////////////////////////////////
void string_forcelink(); // internal
diff --git a/src/zencore/include/zencore/system.h b/src/zencore/include/zencore/system.h
index aec2e0ce4..a67999e52 100644
--- a/src/zencore/include/zencore/system.h
+++ b/src/zencore/include/zencore/system.h
@@ -4,6 +4,8 @@
#include <zencore/zencore.h>
+#include <chrono>
+#include <memory>
#include <string>
namespace zen {
@@ -12,6 +14,8 @@ class CbWriter;
std::string GetMachineName();
std::string_view GetOperatingSystemName();
+std::string GetOperatingSystemVersion();
+std::string_view GetRuntimePlatformName(); // "windows", "wine", "linux", or "macos"
std::string_view GetCpuName();
struct SystemMetrics
@@ -25,6 +29,14 @@ struct SystemMetrics
uint64_t AvailVirtualMemoryMiB = 0;
uint64_t PageFileMiB = 0;
uint64_t AvailPageFileMiB = 0;
+ uint64_t UptimeSeconds = 0;
+};
+
+/// Extended metrics that include CPU usage percentage, which requires
+/// stateful delta tracking via SystemMetricsTracker.
+struct ExtendedSystemMetrics : SystemMetrics
+{
+ float CpuUsagePercent = 0.0f;
};
SystemMetrics GetSystemMetrics();
@@ -32,6 +44,31 @@ SystemMetrics GetSystemMetrics();
void SetCpuCountForReporting(int FakeCpuCount);
SystemMetrics GetSystemMetricsForReporting();
+ExtendedSystemMetrics ApplyReportingOverrides(ExtendedSystemMetrics Metrics);
+
void Describe(const SystemMetrics& Metrics, CbWriter& Writer);
+void Describe(const ExtendedSystemMetrics& Metrics, CbWriter& Writer);
+
+/// Stateful tracker that computes CPU usage as a delta between consecutive
+/// Query() calls. The first call returns CpuUsagePercent = 0 (no previous
+/// sample). Thread-safe: concurrent calls are serialised internally.
+/// CPU sampling is rate-limited to MinInterval (default 1 s); calls that
+/// arrive sooner return the previously cached value.
+class SystemMetricsTracker
+{
+public:
+ explicit SystemMetricsTracker(std::chrono::milliseconds MinInterval = std::chrono::seconds(1));
+ ~SystemMetricsTracker();
+
+ SystemMetricsTracker(const SystemMetricsTracker&) = delete;
+ SystemMetricsTracker& operator=(const SystemMetricsTracker&) = delete;
+
+ /// Collect current metrics. CPU usage is computed as delta since last Query().
+ ExtendedSystemMetrics Query();
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
} // namespace zen
diff --git a/src/zencore/include/zencore/testing.h b/src/zencore/include/zencore/testing.h
index a00ee3166..8410216c4 100644
--- a/src/zencore/include/zencore/testing.h
+++ b/src/zencore/include/zencore/testing.h
@@ -43,8 +43,9 @@ public:
TestRunner();
~TestRunner();
- int ApplyCommandLine(int argc, char const* const* argv);
- int Run();
+ void SetDefaultSuiteFilter(const char* Pattern);
+ int ApplyCommandLine(int Argc, char const* const* Argv);
+ int Run();
private:
struct Impl;
@@ -59,6 +60,8 @@ private:
return Runner.Run(); \
}()
+int RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink)());
+
} // namespace zen::testing
#endif
diff --git a/src/zencore/include/zencore/testutils.h b/src/zencore/include/zencore/testutils.h
index e2a4f8346..2a789d18f 100644
--- a/src/zencore/include/zencore/testutils.h
+++ b/src/zencore/include/zencore/testutils.h
@@ -59,6 +59,33 @@ struct TrueType
static const bool Enabled = true;
};
+namespace utf8test {
+
+ // 2-byte UTF-8 (Latin extended)
+ static constexpr const char kLatin[] = u8"café_résumé";
+ static constexpr const wchar_t kLatinW[] = L"café_résumé";
+
+ // 2-byte UTF-8 (Cyrillic)
+ static constexpr const char kCyrillic[] = u8"данные";
+ static constexpr const wchar_t kCyrillicW[] = L"данные";
+
+ // 3-byte UTF-8 (CJK)
+ static constexpr const char kCJK[] = u8"日本語";
+ static constexpr const wchar_t kCJKW[] = L"日本語";
+
+ // Mixed scripts
+ static constexpr const char kMixed[] = u8"zen_éд日";
+ static constexpr const wchar_t kMixedW[] = L"zen_éд日";
+
+ // 4-byte UTF-8 (supplementary plane) — string tests only, NOT filesystem
+ static constexpr const char kEmoji[] = u8"📦";
+ static constexpr const wchar_t kEmojiW[] = L"📦";
+
+ // BMP-only test strings suitable for filesystem use
+ static constexpr const char* kFilenameSafe[] = {kLatin, kCyrillic, kCJK, kMixed};
+
+} // namespace utf8test
+
} // namespace zen
#endif // ZEN_WITH_TESTS
diff --git a/src/zencore/include/zencore/thread.h b/src/zencore/include/zencore/thread.h
index de8f9399c..d0d710ee8 100644
--- a/src/zencore/include/zencore/thread.h
+++ b/src/zencore/include/zencore/thread.h
@@ -58,17 +58,27 @@ public:
}
private:
- RwLock* m_Lock;
+ RwLock* m_Lock = nullptr;
};
- inline void WithSharedLock(auto&& Fun)
+ inline auto WithSharedLock(auto&& Fun)
{
SharedLockScope $(*this);
- Fun();
+ return Fun();
}
struct ExclusiveLockScope
{
+ ExclusiveLockScope(const ExclusiveLockScope& Rhs) = delete;
+ ExclusiveLockScope(ExclusiveLockScope&& Rhs) : m_Lock(Rhs.m_Lock) { Rhs.m_Lock = nullptr; }
+ ExclusiveLockScope& operator=(ExclusiveLockScope&& Rhs)
+ {
+ ReleaseNow();
+ m_Lock = Rhs.m_Lock;
+ Rhs.m_Lock = nullptr;
+ return *this;
+ }
+ ExclusiveLockScope& operator=(const ExclusiveLockScope& Rhs) = delete;
ExclusiveLockScope(RwLock& Lock) : m_Lock(&Lock) { Lock.AcquireExclusive(); }
~ExclusiveLockScope() { ReleaseNow(); }
@@ -82,13 +92,13 @@ public:
}
private:
- RwLock* m_Lock;
+ RwLock* m_Lock = nullptr;
};
- inline void WithExclusiveLock(auto&& Fun)
+ inline auto WithExclusiveLock(auto&& Fun)
{
ExclusiveLockScope $(*this);
- Fun();
+ return Fun();
}
private:
@@ -195,7 +205,7 @@ public:
// false positive completion results.
void AddCount(std::ptrdiff_t Count)
{
- std::atomic_ptrdiff_t Old = Counter.fetch_add(Count);
+ std::ptrdiff_t Old = Counter.fetch_add(Count);
ZEN_ASSERT(Old > 0);
}
diff --git a/src/zencore/include/zencore/trace.h b/src/zencore/include/zencore/trace.h
index 99a565151..d17e018ea 100644
--- a/src/zencore/include/zencore/trace.h
+++ b/src/zencore/include/zencore/trace.h
@@ -13,6 +13,7 @@ ZEN_THIRD_PARTY_INCLUDES_START
# define TRACE_IMPLEMENT 0
#endif
#include <trace.h>
+#include <lane_trace.h>
#undef TRACE_IMPLEMENT
ZEN_THIRD_PARTY_INCLUDES_END
diff --git a/src/zencore/include/zencore/varint.h b/src/zencore/include/zencore/varint.h
index 9fe905f25..43ca14d38 100644
--- a/src/zencore/include/zencore/varint.h
+++ b/src/zencore/include/zencore/varint.h
@@ -1,4 +1,5 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#pragma once
#include "intmath.h"
diff --git a/src/zencore/include/zencore/xxhash.h b/src/zencore/include/zencore/xxhash.h
index fc55b513b..f79d39b61 100644
--- a/src/zencore/include/zencore/xxhash.h
+++ b/src/zencore/include/zencore/xxhash.h
@@ -87,7 +87,7 @@ struct XXH3_128Stream
}
private:
- XXH3_state_s m_State;
+ XXH3_state_s m_State{};
};
struct XXH3_128Stream_deprecated
diff --git a/src/zencore/include/zencore/zencore.h b/src/zencore/include/zencore/zencore.h
index 177a19fff..a31950b0b 100644
--- a/src/zencore/include/zencore/zencore.h
+++ b/src/zencore/include/zencore/zencore.h
@@ -70,26 +70,36 @@ protected:
} // namespace zen
-#define ZEN_ASSERT(x, ...) \
- do \
- { \
- if (x) [[unlikely]] \
- break; \
- zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, #x); \
+#define ZEN_ASSERT(x, ...) \
+ do \
+ { \
+ if (x) [[unlikely]] \
+ break; \
+ zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, ZEN_ASSERT_MSG_(#x, ##__VA_ARGS__)); \
} while (false)
#ifndef NDEBUG
-# define ZEN_ASSERT_SLOW(x, ...) \
- do \
- { \
- if (x) [[unlikely]] \
- break; \
- zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, #x); \
+# define ZEN_ASSERT_SLOW(x, ...) \
+ do \
+ { \
+ if (x) [[unlikely]] \
+ break; \
+ zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, ZEN_ASSERT_MSG_(#x, ##__VA_ARGS__)); \
} while (false)
#else
# define ZEN_ASSERT_SLOW(x, ...)
#endif
+// Internal: select between "expr" and "expr: message" forms.
+// With no extra args: ZEN_ASSERT_MSG_("expr") -> "expr"
+// With a message arg: ZEN_ASSERT_MSG_("expr", "msg") -> "expr" ": " "msg"
+// With fmt-style args: ZEN_ASSERT_MSG_("expr", "msg", args...) -> "expr" ": " "msg"
+// The extra fmt args are silently discarded here — use ZEN_ASSERT_FORMAT for those.
+#define ZEN_ASSERT_MSG_SELECT_(_1, _2, N, ...) N
+#define ZEN_ASSERT_MSG_1_(expr) expr
+#define ZEN_ASSERT_MSG_2_(expr, msg, ...) expr ": " msg
+#define ZEN_ASSERT_MSG_(expr, ...) ZEN_ASSERT_MSG_SELECT_(unused, ##__VA_ARGS__, ZEN_ASSERT_MSG_2_, ZEN_ASSERT_MSG_1_)(expr, ##__VA_ARGS__)
+
//////////////////////////////////////////////////////////////////////////
#define ZEN_NOT_IMPLEMENTED(...) ZEN_ASSERT(false, __VA_ARGS__)
diff --git a/src/zencore/intmath.cpp b/src/zencore/intmath.cpp
index 5a686dc8e..fedf76edc 100644
--- a/src/zencore/intmath.cpp
+++ b/src/zencore/intmath.cpp
@@ -19,6 +19,8 @@ intmath_forcelink()
{
}
+TEST_SUITE_BEGIN("core.intmath");
+
TEST_CASE("intmath")
{
CHECK(FloorLog2(0x00) == 0);
@@ -43,6 +45,12 @@ TEST_CASE("intmath")
CHECK(FloorLog2_64(0x0000'0001'0000'0000ull) == 32);
CHECK(FloorLog2_64(0x8000'0000'0000'0000ull) == 63);
+ CHECK(CountLeadingZeros(0x8000'0000u) == 0);
+ CHECK(CountLeadingZeros(0x0000'0000u) == 32);
+ CHECK(CountLeadingZeros(0x0000'0001u) == 31);
+ CHECK(CountLeadingZeros(0x0000'8000u) == 16);
+ CHECK(CountLeadingZeros(0x0001'0000u) == 15);
+
CHECK(CountLeadingZeros64(0x8000'0000'0000'0000ull) == 0);
CHECK(CountLeadingZeros64(0x0000'0000'0000'0000ull) == 64);
CHECK(CountLeadingZeros64(0x0000'0000'0000'0001ull) == 63);
@@ -60,6 +68,8 @@ TEST_CASE("intmath")
CHECK(ByteSwap(uint64_t(0x214d'6172'7469'6e21ull)) == 0x216e'6974'7261'4d21ull);
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/iobuffer.cpp b/src/zencore/iobuffer.cpp
index be9b39e7a..c47c54981 100644
--- a/src/zencore/iobuffer.cpp
+++ b/src/zencore/iobuffer.cpp
@@ -592,15 +592,17 @@ IoBufferBuilder::ReadFromFileMaybe(const IoBuffer& InBuffer)
}
IoBuffer
-IoBufferBuilder::MakeFromFileHandle(void* FileHandle, uint64_t Offset, uint64_t Size)
+IoBufferBuilder::MakeFromFileHandle(void* FileHandle, uint64_t Offset, uint64_t Size, ZenContentType ContentType)
{
ZEN_TRACE_CPU("IoBufferBuilder::MakeFromFileHandle");
- return IoBuffer(IoBuffer::BorrowedFile, FileHandle, Offset, Size);
+ IoBuffer Buffer(IoBuffer::BorrowedFile, FileHandle, Offset, Size);
+ Buffer.SetContentType(ContentType);
+ return Buffer;
}
IoBuffer
-IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Offset, uint64_t Size)
+IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Offset, uint64_t Size, ZenContentType ContentType)
{
ZEN_TRACE_CPU("IoBufferBuilder::MakeFromFile");
@@ -632,8 +634,6 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of
FileSize = Stat.st_size;
#endif // ZEN_PLATFORM_WINDOWS
- // TODO: should validate that offset is in range
-
if (Size == ~0ull)
{
Size = FileSize - Offset;
@@ -652,7 +652,9 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of
#if ZEN_PLATFORM_WINDOWS
void* Fd = DataFile.Detach();
#endif
- return IoBuffer(IoBuffer::File, (void*)uintptr_t(Fd), Offset, Size, Offset == 0 && Size == FileSize);
+ IoBuffer NewBuffer(IoBuffer::File, (void*)uintptr_t(Fd), Offset, Size, Offset == 0 && Size == FileSize);
+ NewBuffer.SetContentType(ContentType);
+ return NewBuffer;
}
#if !ZEN_PLATFORM_WINDOWS
@@ -664,7 +666,7 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of
}
IoBuffer
-IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName)
+IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName, ZenContentType ContentType)
{
ZEN_TRACE_CPU("IoBufferBuilder::MakeFromTemporaryFile");
@@ -703,7 +705,9 @@ IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName)
Handle = (void*)uintptr_t(Fd);
#endif // ZEN_PLATFORM_WINDOWS
- return IoBuffer(IoBuffer::File, Handle, 0, FileSize, /*IsWholeFile*/ true);
+ IoBuffer NewBuffer(IoBuffer::File, Handle, 0, FileSize, /*IsWholeFile*/ true);
+ NewBuffer.SetContentType(ContentType);
+ return NewBuffer;
}
//////////////////////////////////////////////////////////////////////////
@@ -715,6 +719,8 @@ iobuffer_forcelink()
{
}
+TEST_SUITE_BEGIN("core.iobuffer");
+
TEST_CASE("IoBuffer")
{
zen::IoBuffer buffer1;
@@ -752,6 +758,8 @@ TEST_CASE("IoBuffer.mmap")
# endif
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/jobqueue.cpp b/src/zencore/jobqueue.cpp
index 75c1be42b..d6a8a6479 100644
--- a/src/zencore/jobqueue.cpp
+++ b/src/zencore/jobqueue.cpp
@@ -90,7 +90,7 @@ public:
uint64_t NewJobId = IdGenerator.fetch_add(1);
if (NewJobId == 0)
{
- IdGenerator.fetch_add(1);
+ NewJobId = IdGenerator.fetch_add(1);
}
RefPtr<Job> NewJob(new Job());
NewJob->Queue = this;
@@ -129,7 +129,7 @@ public:
QueuedJobs.erase(It);
}
});
- ZEN_ERROR("Failed to schedule job {}:'{}' to job queue. Reason: ''", NewJob->Id.Id, NewJob->Name, Ex.what());
+ ZEN_ERROR("Failed to schedule job {}:'{}' to job queue. Reason: '{}'", NewJob->Id.Id, NewJob->Name, Ex.what());
throw;
}
}
@@ -221,11 +221,11 @@ public:
std::vector<JobInfo> Jobs;
QueueLock.WithSharedLock([&]() {
- for (auto It : RunningJobs)
+ for (const auto& It : RunningJobs)
{
Jobs.push_back({.Id = JobId{It.first}, .Status = JobStatus::Running});
}
- for (auto It : CompletedJobs)
+ for (const auto& It : CompletedJobs)
{
if (IsStale(It.second->EndTick))
{
@@ -234,7 +234,7 @@ public:
}
Jobs.push_back({.Id = JobId{It.first}, .Status = JobStatus::Completed});
}
- for (auto It : AbortedJobs)
+ for (const auto& It : AbortedJobs)
{
if (IsStale(It.second->EndTick))
{
@@ -243,7 +243,7 @@ public:
}
Jobs.push_back({.Id = JobId{It.first}, .Status = JobStatus::Aborted});
}
- for (auto It : QueuedJobs)
+ for (const auto& It : QueuedJobs)
{
Jobs.push_back({.Id = It->Id, .Status = JobStatus::Queued});
}
@@ -337,7 +337,7 @@ public:
std::atomic_bool InitializedFlag = false;
RwLock QueueLock;
std::deque<RefPtr<Job>> QueuedJobs;
- std::unordered_map<uint64_t, Job*> RunningJobs;
+ std::unordered_map<uint64_t, RefPtr<Job>> RunningJobs;
std::unordered_map<uint64_t, RefPtr<Job>> CompletedJobs;
std::unordered_map<uint64_t, RefPtr<Job>> AbortedJobs;
@@ -429,20 +429,16 @@ JobQueue::ToString(JobStatus Status)
{
case JobQueue::JobStatus::Queued:
return "Queued"sv;
- break;
case JobQueue::JobStatus::Running:
return "Running"sv;
- break;
case JobQueue::JobStatus::Aborted:
return "Aborted"sv;
- break;
case JobQueue::JobStatus::Completed:
return "Completed"sv;
- break;
default:
ZEN_ASSERT(false);
+ return ""sv;
}
- return ""sv;
}
std::unique_ptr<JobQueue>
@@ -460,6 +456,8 @@ jobqueue_forcelink()
{
}
+TEST_SUITE_BEGIN("core.jobqueue");
+
TEST_CASE("JobQueue")
{
std::unique_ptr<JobQueue> Queue(MakeJobQueue(2, "queue"));
@@ -580,6 +578,8 @@ TEST_CASE("JobQueue")
}
JobsLatch.Wait();
}
+
+TEST_SUITE_END();
#endif
} // namespace zen
diff --git a/src/zencore/logging.cpp b/src/zencore/logging.cpp
index a6697c443..099518637 100644
--- a/src/zencore/logging.cpp
+++ b/src/zencore/logging.cpp
@@ -2,208 +2,128 @@
#include "zencore/logging.h"
+#include <zencore/logging/ansicolorsink.h>
+#include <zencore/logging/logger.h>
+#include <zencore/logging/messageonlyformatter.h>
+#include <zencore/logging/nullsink.h>
+#include <zencore/logging/registry.h>
#include <zencore/string.h>
#include <zencore/testing.h>
#include <zencore/thread.h>
#include <zencore/memory/llm.h>
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <spdlog/details/registry.h>
-#include <spdlog/sinks/null_sink.h>
-#include <spdlog/sinks/stdout_color_sinks.h>
-#include <spdlog/spdlog.h>
-ZEN_THIRD_PARTY_INCLUDES_END
+#include <mutex>
#if ZEN_PLATFORM_WINDOWS
# pragma section(".zlog$a", read)
-# pragma section(".zlog$f", read)
-# pragma section(".zlog$m", read)
-# pragma section(".zlog$s", read)
+# pragma section(".zlog$l", read)
# pragma section(".zlog$z", read)
#endif
namespace zen {
-// We shadow the underlying spdlog default logger, in order to avoid a bunch of overhead
LoggerRef TheDefaultLogger;
+LoggerRef
+Log()
+{
+ if (TheDefaultLogger)
+ {
+ return TheDefaultLogger;
+ }
+ return zen::logging::ConsoleLog();
+}
+
} // namespace zen
namespace zen::logging {
-using MemoryBuffer_t = fmt::basic_memory_buffer<char, 250>;
-
-struct LoggingContext
-{
- inline LoggingContext();
- inline ~LoggingContext();
-
- zen::logging::MemoryBuffer_t MessageBuffer;
-
- inline std::string_view Message() const { return std::string_view(MessageBuffer.data(), MessageBuffer.size()); }
-};
+//////////////////////////////////////////////////////////////////////////
-LoggingContext::LoggingContext()
+LoggerRef
+LogCategory::Logger()
{
-}
+ // This should be thread safe since zen::logging::Get() will return
+ // the same logger instance for the same category name. Also the
+ // LoggerRef is simply a pointer.
+ if (!m_LoggerRef)
+ {
+ m_LoggerRef = zen::logging::Get(m_CategoryName);
+ }
-LoggingContext::~LoggingContext()
-{
+ return m_LoggerRef;
}
-//////////////////////////////////////////////////////////////////////////
-
static inline bool
-IsErrorLevel(int LogLevel)
+IsErrorLevel(LogLevel InLevel)
{
- return (LogLevel == zen::logging::level::Err || LogLevel == zen::logging::level::Critical);
+ return (InLevel == Err || InLevel == Critical);
};
-static_assert(sizeof(spdlog::source_loc) == sizeof(SourceLocation));
-static_assert(offsetof(spdlog::source_loc, filename) == offsetof(SourceLocation, filename));
-static_assert(offsetof(spdlog::source_loc, line) == offsetof(SourceLocation, line));
-static_assert(offsetof(spdlog::source_loc, funcname) == offsetof(SourceLocation, funcname));
-
void
-EmitLogMessage(LoggerRef& Logger, int LogLevel, const std::string_view Message)
+EmitLogMessage(LoggerRef& Logger, const LogPoint& Lp, fmt::format_args Args)
{
ZEN_MEMSCOPE(ELLMTag::Logging);
- const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel;
- Logger.SpdLogger->log(InLevel, Message);
- if (IsErrorLevel(LogLevel))
- {
- if (LoggerRef ErrLogger = zen::logging::ErrorLog())
- {
- ErrLogger.SpdLogger->log(InLevel, Message);
- }
- }
-}
-void
-EmitLogMessage(LoggerRef& Logger, int LogLevel, std::string_view Format, fmt::format_args Args)
-{
- ZEN_MEMSCOPE(ELLMTag::Logging);
- zen::logging::LoggingContext LogCtx;
- fmt::vformat_to(fmt::appender(LogCtx.MessageBuffer), Format, Args);
- zen::logging::EmitLogMessage(Logger, LogLevel, LogCtx.Message());
-}
+ Logger->Log(Lp, Args);
-void
-EmitLogMessage(LoggerRef& Logger, const SourceLocation& InLocation, int LogLevel, const std::string_view Message)
-{
- ZEN_MEMSCOPE(ELLMTag::Logging);
- const spdlog::source_loc& Location = *reinterpret_cast<const spdlog::source_loc*>(&InLocation);
- const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel;
- Logger.SpdLogger->log(Location, InLevel, Message);
- if (IsErrorLevel(LogLevel))
+ if (IsErrorLevel(Lp.Level))
{
if (LoggerRef ErrLogger = zen::logging::ErrorLog())
{
- ErrLogger.SpdLogger->log(Location, InLevel, Message);
+ ErrLogger->Log(Lp, Args);
}
}
}
void
-EmitLogMessage(LoggerRef& Logger, const SourceLocation& InLocation, int LogLevel, std::string_view Format, fmt::format_args Args)
+EmitConsoleLogMessage(const LogPoint& Lp, fmt::format_args Args)
{
ZEN_MEMSCOPE(ELLMTag::Logging);
- zen::logging::LoggingContext LogCtx;
- fmt::vformat_to(fmt::appender(LogCtx.MessageBuffer), Format, Args);
- zen::logging::EmitLogMessage(Logger, InLocation, LogLevel, LogCtx.Message());
-}
-
-void
-EmitConsoleLogMessage(int LogLevel, const std::string_view Message)
-{
- ZEN_MEMSCOPE(ELLMTag::Logging);
- const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel;
- ConsoleLog().SpdLogger->log(InLevel, Message);
-}
-
-#define ZEN_COLOR_YELLOW "\033[0;33m"
-#define ZEN_COLOR_RED "\033[0;31m"
-#define ZEN_BRIGHT_COLOR_RED "\033[1;31m"
-#define ZEN_COLOR_RESET "\033[0m"
-
-void
-EmitConsoleLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args)
-{
- ZEN_MEMSCOPE(ELLMTag::Logging);
- zen::logging::LoggingContext LogCtx;
-
- // We are not using a format option for console which include log level since it would interfere with normal console output
-
- const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel;
- switch (InLevel)
- {
- case spdlog::level::level_enum::warn:
- fmt::format_to(fmt::appender(LogCtx.MessageBuffer), ZEN_COLOR_YELLOW "Warning: " ZEN_COLOR_RESET);
- break;
- case spdlog::level::level_enum::err:
- fmt::format_to(fmt::appender(LogCtx.MessageBuffer), ZEN_BRIGHT_COLOR_RED "Error: " ZEN_COLOR_RESET);
- break;
- case spdlog::level::level_enum::critical:
- fmt::format_to(fmt::appender(LogCtx.MessageBuffer), ZEN_COLOR_RED "Critical: " ZEN_COLOR_RESET);
- break;
- default:
- break;
- }
- fmt::vformat_to(fmt::appender(LogCtx.MessageBuffer), Format, Args);
- zen::logging::EmitConsoleLogMessage(LogLevel, LogCtx.Message());
+ ConsoleLog()->Log(Lp, Args);
}
} // namespace zen::logging
-namespace zen::logging::level {
+namespace zen::logging {
-spdlog::level::level_enum
-to_spdlog_level(LogLevel NewLogLevel)
-{
- return static_cast<spdlog::level::level_enum>((int)NewLogLevel);
-}
+constinit std::string_view LevelNames[] = {std::string_view("trace", 5),
+ std::string_view("debug", 5),
+ std::string_view("info", 4),
+ std::string_view("warning", 7),
+ std::string_view("error", 5),
+ std::string_view("critical", 8),
+ std::string_view("off", 3)};
LogLevel
-to_logging_level(spdlog::level::level_enum NewLogLevel)
-{
- return static_cast<LogLevel>((int)NewLogLevel);
-}
-
-constinit std::string_view LevelNames[] = {ZEN_LEVEL_NAME_TRACE,
- ZEN_LEVEL_NAME_DEBUG,
- ZEN_LEVEL_NAME_INFO,
- ZEN_LEVEL_NAME_WARNING,
- ZEN_LEVEL_NAME_ERROR,
- ZEN_LEVEL_NAME_CRITICAL,
- ZEN_LEVEL_NAME_OFF};
-
-level::LogLevel
ParseLogLevelString(std::string_view Name)
{
- for (int Level = 0; Level < level::LogLevelCount; ++Level)
+ for (int Level = 0; Level < LogLevelCount; ++Level)
{
if (LevelNames[Level] == Name)
- return static_cast<level::LogLevel>(Level);
+ {
+ return static_cast<LogLevel>(Level);
+ }
}
if (Name == "warn")
{
- return level::Warn;
+ return Warn;
}
if (Name == "err")
{
- return level::Err;
+ return Err;
}
- return level::Off;
+ return Off;
}
std::string_view
-ToStringView(level::LogLevel Level)
+ToStringView(LogLevel Level)
{
- if (int(Level) < level::LogLevelCount)
+ if (int(Level) < LogLevelCount)
{
return LevelNames[int(Level)];
}
@@ -211,17 +131,17 @@ ToStringView(level::LogLevel Level)
return "None";
}
-} // namespace zen::logging::level
+} // namespace zen::logging
//////////////////////////////////////////////////////////////////////////
namespace zen::logging {
RwLock LogLevelsLock;
-std::string LogLevels[level::LogLevelCount];
+std::string LogLevels[LogLevelCount];
void
-ConfigureLogLevels(level::LogLevel Level, std::string_view Loggers)
+ConfigureLogLevels(LogLevel Level, std::string_view Loggers)
{
ZEN_MEMSCOPE(ELLMTag::Logging);
@@ -230,18 +150,18 @@ ConfigureLogLevels(level::LogLevel Level, std::string_view Loggers)
}
void
-RefreshLogLevels(level::LogLevel* DefaultLevel)
+RefreshLogLevels(LogLevel* DefaultLevel)
{
ZEN_MEMSCOPE(ELLMTag::Logging);
- spdlog::details::registry::log_levels Levels;
+ std::vector<std::pair<std::string, LogLevel>> Levels;
{
RwLock::SharedLockScope _(LogLevelsLock);
- for (int i = 0; i < level::LogLevelCount; ++i)
+ for (int i = 0; i < LogLevelCount; ++i)
{
- level::LogLevel CurrentLevel{i};
+ LogLevel CurrentLevel{i};
std::string_view Spec = LogLevels[i];
@@ -251,7 +171,7 @@ RefreshLogLevels(level::LogLevel* DefaultLevel)
if (auto CommaPos = Spec.find_first_of(','); CommaPos != std::string_view::npos)
{
- LoggerName = Spec.substr(CommaPos + 1);
+ LoggerName = Spec.substr(0, CommaPos);
Spec.remove_prefix(CommaPos + 1);
}
else
@@ -260,24 +180,16 @@ RefreshLogLevels(level::LogLevel* DefaultLevel)
Spec = {};
}
- Levels[LoggerName] = to_spdlog_level(CurrentLevel);
+ Levels.emplace_back(std::move(LoggerName), CurrentLevel);
}
}
}
- if (DefaultLevel)
- {
- spdlog::level::level_enum SpdDefaultLevel = to_spdlog_level(*DefaultLevel);
- spdlog::details::registry::instance().set_levels(Levels, &SpdDefaultLevel);
- }
- else
- {
- spdlog::details::registry::instance().set_levels(Levels, nullptr);
- }
+ Registry::Instance().SetLevels(Levels, DefaultLevel);
}
void
-RefreshLogLevels(level::LogLevel DefaultLevel)
+RefreshLogLevels(LogLevel DefaultLevel)
{
RefreshLogLevels(&DefaultLevel);
}
@@ -289,21 +201,21 @@ RefreshLogLevels()
}
void
-SetLogLevel(level::LogLevel NewLogLevel)
+SetLogLevel(LogLevel NewLogLevel)
{
- spdlog::set_level(to_spdlog_level(NewLogLevel));
+ Registry::Instance().SetGlobalLevel(NewLogLevel);
}
-level::LogLevel
+LogLevel
GetLogLevel()
{
- return level::to_logging_level(spdlog::get_level());
+ return Registry::Instance().GetGlobalLevel();
}
LoggerRef
Default()
{
- ZEN_ASSERT(TheDefaultLogger);
+ ZEN_ASSERT(TheDefaultLogger, "logging::InitializeLogging() must be called before using the logger");
return TheDefaultLogger;
}
@@ -312,10 +224,10 @@ SetDefault(std::string_view NewDefaultLoggerId)
{
ZEN_MEMSCOPE(ELLMTag::Logging);
- auto NewDefaultLogger = spdlog::get(std::string(NewDefaultLoggerId));
+ Ref<Logger> NewDefaultLogger = Registry::Instance().Get(std::string(NewDefaultLoggerId));
ZEN_ASSERT(NewDefaultLogger);
- spdlog::set_default_logger(NewDefaultLogger);
+ Registry::Instance().SetDefaultLogger(NewDefaultLogger);
TheDefaultLogger = LoggerRef(*NewDefaultLogger);
}
@@ -338,11 +250,11 @@ SetErrorLog(std::string_view NewErrorLoggerId)
}
else
{
- auto NewErrorLogger = spdlog::get(std::string(NewErrorLoggerId));
+ Ref<Logger> NewErrorLogger = Registry::Instance().Get(std::string(NewErrorLoggerId));
ZEN_ASSERT(NewErrorLogger);
- TheErrorLogger = LoggerRef(*NewErrorLogger.get());
+ TheErrorLogger = LoggerRef(*NewErrorLogger.Get());
}
}
@@ -353,39 +265,75 @@ Get(std::string_view Name)
{
ZEN_MEMSCOPE(ELLMTag::Logging);
- std::shared_ptr<spdlog::logger> Logger = spdlog::get(std::string(Name));
+ Ref<Logger> FoundLogger = Registry::Instance().Get(std::string(Name));
- if (!Logger)
+ if (!FoundLogger)
{
g_LoggerMutex.WithExclusiveLock([&] {
- Logger = spdlog::get(std::string(Name));
+ FoundLogger = Registry::Instance().Get(std::string(Name));
- if (!Logger)
+ if (!FoundLogger)
{
- Logger = Default().SpdLogger->clone(std::string(Name));
- spdlog::apply_logger_env_levels(Logger);
- spdlog::register_logger(Logger);
+ FoundLogger = Default()->Clone(std::string(Name));
+ Registry::Instance().Register(FoundLogger);
}
});
}
- return *Logger;
+ return *FoundLogger;
}
-std::once_flag ConsoleInitFlag;
-std::shared_ptr<spdlog::logger> ConLogger;
+std::once_flag ConsoleInitFlag;
+Ref<Logger> ConLogger;
void
SuppressConsoleLog()
{
+ ZEN_MEMSCOPE(ELLMTag::Logging);
+
if (ConLogger)
{
- spdlog::drop("console");
+ Registry::Instance().Drop("console");
ConLogger = {};
}
- ConLogger = spdlog::null_logger_mt("console");
+
+ SinkPtr NullSinkPtr(new NullSink());
+ ConLogger = Ref<Logger>(new Logger("console", std::vector<SinkPtr>{NullSinkPtr}));
+ Registry::Instance().Register(ConLogger);
}
+#define ZEN_COLOR_YELLOW "\033[0;33m"
+#define ZEN_COLOR_RED "\033[0;31m"
+#define ZEN_BRIGHT_COLOR_RED "\033[1;31m"
+#define ZEN_COLOR_RESET "\033[0m"
+
+class ConsoleFormatter : public Formatter
+{
+public:
+ void Format(const LogMessage& Msg, MemoryBuffer& Dest) override
+ {
+ switch (Msg.GetLevel())
+ {
+ case Warn:
+ fmt::format_to(fmt::appender(Dest), ZEN_COLOR_YELLOW "Warning: " ZEN_COLOR_RESET);
+ break;
+ case Err:
+ fmt::format_to(fmt::appender(Dest), ZEN_BRIGHT_COLOR_RED "Error: " ZEN_COLOR_RESET);
+ break;
+ case Critical:
+ fmt::format_to(fmt::appender(Dest), ZEN_COLOR_RED "Critical: " ZEN_COLOR_RESET);
+ break;
+ default:
+ break;
+ }
+
+ helpers::AppendStringView(Msg.GetPayload(), Dest);
+ Dest.push_back('\n');
+ }
+
+ std::unique_ptr<Formatter> Clone() const override { return std::make_unique<ConsoleFormatter>(); }
+};
+
LoggerRef
ConsoleLog()
{
@@ -394,10 +342,10 @@ ConsoleLog()
std::call_once(ConsoleInitFlag, [&] {
if (!ConLogger)
{
- ConLogger = spdlog::stdout_color_mt("console");
- spdlog::apply_logger_env_levels(ConLogger);
-
- ConLogger->set_pattern("%v");
+ SinkPtr ConsoleSink(new AnsiColorStdoutSink());
+ ConsoleSink->SetFormatter(std::make_unique<ConsoleFormatter>());
+ ConLogger = Ref<Logger>(new Logger("console", std::vector<SinkPtr>{ConsoleSink}));
+ Registry::Instance().Register(ConLogger);
}
});
@@ -405,17 +353,29 @@ ConsoleLog()
}
void
+ResetConsoleLog()
+{
+ ZEN_MEMSCOPE(ELLMTag::Logging);
+
+ LoggerRef ConLog = ConsoleLog();
+
+ ConLog->SetFormatter(std::make_unique<ConsoleFormatter>());
+}
+
+void
InitializeLogging()
{
ZEN_MEMSCOPE(ELLMTag::Logging);
- TheDefaultLogger = *spdlog::default_logger_raw();
+ TheDefaultLogger = *Registry::Instance().DefaultLoggerRaw();
}
void
ShutdownLogging()
{
- spdlog::shutdown();
+ ZEN_MEMSCOPE(ELLMTag::Logging);
+
+ Registry::Instance().Shutdown();
TheDefaultLogger = {};
}
@@ -449,7 +409,7 @@ EnableVTMode()
void
FlushLogging()
{
- spdlog::details::registry::instance().flush_all();
+ Registry::Instance().FlushAll();
}
} // namespace zen::logging
@@ -457,21 +417,27 @@ FlushLogging()
namespace zen {
bool
-LoggerRef::ShouldLog(int Level) const
+LoggerRef::ShouldLog(logging::LogLevel Level) const
{
- return SpdLogger->should_log(static_cast<spdlog::level::level_enum>(Level));
+ return m_Logger->ShouldLog(Level);
}
void
-LoggerRef::SetLogLevel(logging::level::LogLevel NewLogLevel)
+LoggerRef::SetLogLevel(logging::LogLevel NewLogLevel)
{
- SpdLogger->set_level(to_spdlog_level(NewLogLevel));
+ m_Logger->SetLevel(NewLogLevel);
}
-logging::level::LogLevel
+logging::LogLevel
LoggerRef::GetLogLevel()
{
- return logging::level::to_logging_level(SpdLogger->level());
+ return m_Logger->GetLevel();
+}
+
+void
+LoggerRef::Flush()
+{
+ m_Logger->Flush();
}
thread_local ScopedActivityBase* t_ScopeStack = nullptr;
@@ -532,6 +498,8 @@ logging_forcelink()
using namespace std::literals;
+TEST_SUITE_BEGIN("core.logging");
+
TEST_CASE("simple.bread")
{
ExtendableStringBuilder<256> Crumbs;
@@ -580,6 +548,8 @@ TEST_CASE("simple.bread")
}
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/logging/ansicolorsink.cpp b/src/zencore/logging/ansicolorsink.cpp
new file mode 100644
index 000000000..540d22359
--- /dev/null
+++ b/src/zencore/logging/ansicolorsink.cpp
@@ -0,0 +1,273 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/logging/ansicolorsink.h>
+#include <zencore/logging/helpers.h>
+#include <zencore/logging/messageonlyformatter.h>
+
+#include <cstdio>
+#include <cstdlib>
+#include <mutex>
+
+#if defined(_WIN32)
+# include <io.h>
+# define ZEN_ISATTY _isatty
+# define ZEN_FILENO _fileno
+#else
+# include <unistd.h>
+# define ZEN_ISATTY isatty
+# define ZEN_FILENO fileno
+#endif
+
+namespace zen::logging {
+
+// Default formatter replicating spdlog's %+ pattern:
+// [YYYY-MM-DD HH:MM:SS.mmm] [logger_name] [level] message\n
+class DefaultConsoleFormatter : public Formatter
+{
+public:
+ void Format(const LogMessage& Msg, MemoryBuffer& Dest) override
+ {
+ // timestamp
+ auto Secs = std::chrono::duration_cast<std::chrono::seconds>(Msg.GetTime().time_since_epoch());
+ if (Secs != m_LastLogSecs)
+ {
+ m_LastLogSecs = Secs;
+ m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime()));
+ }
+
+ Dest.push_back('[');
+ helpers::AppendInt(m_CachedLocalTm.tm_year + 1900, Dest);
+ Dest.push_back('-');
+ helpers::Pad2(m_CachedLocalTm.tm_mon + 1, Dest);
+ Dest.push_back('-');
+ helpers::Pad2(m_CachedLocalTm.tm_mday, Dest);
+ Dest.push_back(' ');
+ helpers::Pad2(m_CachedLocalTm.tm_hour, Dest);
+ Dest.push_back(':');
+ helpers::Pad2(m_CachedLocalTm.tm_min, Dest);
+ Dest.push_back(':');
+ helpers::Pad2(m_CachedLocalTm.tm_sec, Dest);
+ Dest.push_back('.');
+ auto Millis = helpers::TimeFraction<std::chrono::milliseconds>(Msg.GetTime());
+ helpers::Pad3(static_cast<uint32_t>(Millis.count()), Dest);
+ Dest.push_back(']');
+ Dest.push_back(' ');
+
+ // logger name
+ if (Msg.GetLoggerName().size() > 0)
+ {
+ Dest.push_back('[');
+ helpers::AppendStringView(Msg.GetLoggerName(), Dest);
+ Dest.push_back(']');
+ Dest.push_back(' ');
+ }
+
+ // level (colored range)
+ Dest.push_back('[');
+ Msg.ColorRangeStart = Dest.size();
+ helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest);
+ Msg.ColorRangeEnd = Dest.size();
+ Dest.push_back(']');
+ Dest.push_back(' ');
+
+ // message
+ helpers::AppendStringView(Msg.GetPayload(), Dest);
+ Dest.push_back('\n');
+ }
+
+ std::unique_ptr<Formatter> Clone() const override { return std::make_unique<DefaultConsoleFormatter>(); }
+
+private:
+ std::chrono::seconds m_LastLogSecs{0};
+ std::tm m_CachedLocalTm{};
+};
+
+static constexpr std::string_view s_Reset = "\033[m";
+
+static std::string_view
+GetColorForLevel(LogLevel InLevel)
+{
+ using namespace std::string_view_literals;
+ switch (InLevel)
+ {
+ case Trace:
+ return "\033[37m"sv; // white
+ case Debug:
+ return "\033[36m"sv; // cyan
+ case Info:
+ return "\033[32m"sv; // green
+ case Warn:
+ return "\033[33m\033[1m"sv; // bold yellow
+ case Err:
+ return "\033[31m\033[1m"sv; // bold red
+ case Critical:
+ return "\033[1m\033[41m"sv; // bold on red background
+ default:
+ return s_Reset;
+ }
+}
+
+struct AnsiColorStdoutSink::Impl
+{
+ explicit Impl(ColorMode Mode) : m_Formatter(std::make_unique<DefaultConsoleFormatter>()), m_UseColor(ResolveColorMode(Mode)) {}
+
+ static bool IsColorTerminal()
+ {
+ // If stdout is not a TTY, no color
+ if (ZEN_ISATTY(ZEN_FILENO(stdout)) == 0)
+ {
+ return false;
+ }
+
+ // NO_COLOR convention (https://no-color.org/)
+ if (std::getenv("NO_COLOR") != nullptr)
+ {
+ return false;
+ }
+
+ // COLORTERM is set by terminals that support color (e.g. "truecolor", "24bit")
+ if (std::getenv("COLORTERM") != nullptr)
+ {
+ return true;
+ }
+
+ // Check TERM for known color-capable values
+ const char* Term = std::getenv("TERM");
+ if (Term != nullptr)
+ {
+ std::string_view TermView(Term);
+ // "dumb" terminals do not support color
+ if (TermView == "dumb")
+ {
+ return false;
+ }
+ // Match against known color-capable terminal types.
+ // TERM often includes suffixes like "-256color", so we use substring matching.
+ constexpr std::string_view ColorTerms[] = {
+ "alacritty",
+ "ansi",
+ "color",
+ "console",
+ "cygwin",
+ "gnome",
+ "konsole",
+ "kterm",
+ "linux",
+ "msys",
+ "putty",
+ "rxvt",
+ "screen",
+ "tmux",
+ "vt100",
+ "vt102",
+ "xterm",
+ };
+ for (std::string_view Candidate : ColorTerms)
+ {
+ if (TermView.find(Candidate) != std::string_view::npos)
+ {
+ return true;
+ }
+ }
+ }
+
+#if defined(_WIN32)
+ // Windows console supports ANSI color by default in modern versions
+ return true;
+#else
+ // Unknown terminal — be conservative
+ return false;
+#endif
+ }
+
+ static bool ResolveColorMode(ColorMode Mode)
+ {
+ switch (Mode)
+ {
+ case ColorMode::On:
+ return true;
+ case ColorMode::Off:
+ return false;
+ case ColorMode::Auto:
+ default:
+ return IsColorTerminal();
+ }
+ }
+
+ void Log(const LogMessage& Msg)
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+
+ MemoryBuffer Formatted;
+ m_Formatter->Format(Msg, Formatted);
+
+ if (m_UseColor && Msg.ColorRangeEnd > Msg.ColorRangeStart)
+ {
+ // Print pre-color range
+ fwrite(Formatted.data(), 1, Msg.ColorRangeStart, m_File);
+
+ // Print color
+ std::string_view Color = GetColorForLevel(Msg.GetLevel());
+ fwrite(Color.data(), 1, Color.size(), m_File);
+
+ // Print colored range
+ fwrite(Formatted.data() + Msg.ColorRangeStart, 1, Msg.ColorRangeEnd - Msg.ColorRangeStart, m_File);
+
+ // Reset color
+ fwrite(s_Reset.data(), 1, s_Reset.size(), m_File);
+
+ // Print remainder
+ fwrite(Formatted.data() + Msg.ColorRangeEnd, 1, Formatted.size() - Msg.ColorRangeEnd, m_File);
+ }
+ else
+ {
+ fwrite(Formatted.data(), 1, Formatted.size(), m_File);
+ }
+
+ fflush(m_File);
+ }
+
+ void Flush()
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ fflush(m_File);
+ }
+
+ void SetFormatter(std::unique_ptr<Formatter> InFormatter)
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ m_Formatter = std::move(InFormatter);
+ }
+
+private:
+ std::mutex m_Mutex;
+ std::unique_ptr<Formatter> m_Formatter;
+ FILE* m_File = stdout;
+ bool m_UseColor = true;
+};
+
+AnsiColorStdoutSink::AnsiColorStdoutSink(ColorMode Mode) : m_Impl(std::make_unique<Impl>(Mode))
+{
+}
+
+AnsiColorStdoutSink::~AnsiColorStdoutSink() = default;
+
+void
+AnsiColorStdoutSink::Log(const LogMessage& Msg)
+{
+ m_Impl->Log(Msg);
+}
+
+void
+AnsiColorStdoutSink::Flush()
+{
+ m_Impl->Flush();
+}
+
+void
+AnsiColorStdoutSink::SetFormatter(std::unique_ptr<Formatter> InFormatter)
+{
+ m_Impl->SetFormatter(std::move(InFormatter));
+}
+
+} // namespace zen::logging
diff --git a/src/zencore/logging/asyncsink.cpp b/src/zencore/logging/asyncsink.cpp
new file mode 100644
index 000000000..02bf9f3ba
--- /dev/null
+++ b/src/zencore/logging/asyncsink.cpp
@@ -0,0 +1,212 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/logging/asyncsink.h>
+
+#include <zencore/blockingqueue.h>
+#include <zencore/logging/logmsg.h>
+#include <zencore/thread.h>
+
+#include <future>
+#include <string>
+#include <thread>
+
+namespace zen::logging {
+
+struct AsyncLogMessage
+{
+ enum class Type : uint8_t
+ {
+ Log,
+ Flush,
+ Shutdown
+ };
+
+ Type MsgType = Type::Log;
+
+ // Points to the LogPoint from upstream logging code. LogMessage guarantees
+ // this is always valid (either a static LogPoint from ZEN_LOG macros or one
+ // of the per-level default LogPoints).
+ const LogPoint* Point = nullptr;
+
+ int ThreadId = 0;
+ std::string OwnedPayload;
+ std::string OwnedLoggerName;
+ std::chrono::system_clock::time_point Time;
+
+ std::shared_ptr<std::promise<void>> FlushPromise;
+};
+
+struct AsyncSink::Impl
+{
+ explicit Impl(std::vector<SinkPtr> InSinks) : m_Sinks(std::move(InSinks))
+ {
+ m_WorkerThread = std::thread([this]() {
+ zen::SetCurrentThreadName("AsyncLog");
+ WorkerLoop();
+ });
+ }
+
+ ~Impl()
+ {
+ AsyncLogMessage ShutdownMsg;
+ ShutdownMsg.MsgType = AsyncLogMessage::Type::Shutdown;
+ m_Queue.Enqueue(std::move(ShutdownMsg));
+
+ if (m_WorkerThread.joinable())
+ {
+ m_WorkerThread.join();
+ }
+ }
+
+ void Log(const LogMessage& Msg)
+ {
+ AsyncLogMessage AsyncMsg;
+ AsyncMsg.OwnedPayload = std::string(Msg.GetPayload());
+ AsyncMsg.OwnedLoggerName = std::string(Msg.GetLoggerName());
+ AsyncMsg.ThreadId = Msg.GetThreadId();
+ AsyncMsg.Time = Msg.GetTime();
+ AsyncMsg.Point = &Msg.GetLogPoint();
+ AsyncMsg.MsgType = AsyncLogMessage::Type::Log;
+
+ m_Queue.Enqueue(std::move(AsyncMsg));
+ }
+
+ void Flush()
+ {
+ auto Promise = std::make_shared<std::promise<void>>();
+ auto Future = Promise->get_future();
+
+ AsyncLogMessage FlushMsg;
+ FlushMsg.MsgType = AsyncLogMessage::Type::Flush;
+ FlushMsg.FlushPromise = std::move(Promise);
+
+ m_Queue.Enqueue(std::move(FlushMsg));
+
+ Future.get();
+ }
+
+ void SetFormatter(std::unique_ptr<Formatter> InFormatter)
+ {
+ for (auto& CurrentSink : m_Sinks)
+ {
+ CurrentSink->SetFormatter(InFormatter->Clone());
+ }
+ }
+
+private:
+ void ForwardLogToSinks(const AsyncLogMessage& AsyncMsg)
+ {
+ LogMessage Reconstructed(*AsyncMsg.Point, AsyncMsg.OwnedLoggerName, AsyncMsg.OwnedPayload);
+ Reconstructed.SetTime(AsyncMsg.Time);
+ Reconstructed.SetThreadId(AsyncMsg.ThreadId);
+
+ for (auto& CurrentSink : m_Sinks)
+ {
+ if (CurrentSink->ShouldLog(Reconstructed.GetLevel()))
+ {
+ try
+ {
+ CurrentSink->Log(Reconstructed);
+ }
+ catch (const std::exception&)
+ {
+ }
+ }
+ }
+ }
+
+ void FlushSinks()
+ {
+ for (auto& CurrentSink : m_Sinks)
+ {
+ try
+ {
+ CurrentSink->Flush();
+ }
+ catch (const std::exception&)
+ {
+ }
+ }
+ }
+
+ void WorkerLoop()
+ {
+ AsyncLogMessage Msg;
+ while (m_Queue.WaitAndDequeue(Msg))
+ {
+ switch (Msg.MsgType)
+ {
+ case AsyncLogMessage::Type::Log:
+ {
+ ForwardLogToSinks(Msg);
+ break;
+ }
+
+ case AsyncLogMessage::Type::Flush:
+ {
+ FlushSinks();
+ if (Msg.FlushPromise)
+ {
+ Msg.FlushPromise->set_value();
+ }
+ break;
+ }
+
+ case AsyncLogMessage::Type::Shutdown:
+ {
+ m_Queue.CompleteAdding();
+
+ AsyncLogMessage Remaining;
+ while (m_Queue.WaitAndDequeue(Remaining))
+ {
+ if (Remaining.MsgType == AsyncLogMessage::Type::Log)
+ {
+ ForwardLogToSinks(Remaining);
+ }
+ else if (Remaining.MsgType == AsyncLogMessage::Type::Flush)
+ {
+ FlushSinks();
+ if (Remaining.FlushPromise)
+ {
+ Remaining.FlushPromise->set_value();
+ }
+ }
+ }
+
+ FlushSinks();
+ return;
+ }
+ }
+ }
+ }
+
+ std::vector<SinkPtr> m_Sinks;
+ BlockingQueue<AsyncLogMessage> m_Queue;
+ std::thread m_WorkerThread;
+};
+
+AsyncSink::AsyncSink(std::vector<SinkPtr> InSinks) : m_Impl(std::make_unique<Impl>(std::move(InSinks)))
+{
+}
+
+AsyncSink::~AsyncSink() = default;
+
+void
+AsyncSink::Log(const LogMessage& Msg)
+{
+ m_Impl->Log(Msg);
+}
+
+void
+AsyncSink::Flush()
+{
+ m_Impl->Flush();
+}
+
+void
+AsyncSink::SetFormatter(std::unique_ptr<Formatter> InFormatter)
+{
+ m_Impl->SetFormatter(std::move(InFormatter));
+}
+
+} // namespace zen::logging
diff --git a/src/zencore/logging/logger.cpp b/src/zencore/logging/logger.cpp
new file mode 100644
index 000000000..dd1675bb1
--- /dev/null
+++ b/src/zencore/logging/logger.cpp
@@ -0,0 +1,142 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/logging/logger.h>
+#include <zencore/thread.h>
+
+#include <string>
+#include <vector>
+
+namespace zen::logging {
+
+struct Logger::Impl
+{
+ std::string m_Name;
+ std::vector<SinkPtr> m_Sinks;
+ ErrorHandler* m_ErrorHandler = nullptr;
+};
+
+Logger::Logger(std::string_view InName, SinkPtr InSink) : m_Impl(std::make_unique<Impl>())
+{
+ m_Impl->m_Name = InName;
+ m_Impl->m_Sinks.push_back(std::move(InSink));
+}
+
+Logger::Logger(std::string_view InName, std::span<const SinkPtr> InSinks) : m_Impl(std::make_unique<Impl>())
+{
+ m_Impl->m_Name = InName;
+ m_Impl->m_Sinks.assign(InSinks.begin(), InSinks.end());
+}
+
+Logger::~Logger() = default;
+
+void
+Logger::Log(const LogPoint& Point, fmt::format_args Args)
+{
+ if (!ShouldLog(Point.Level))
+ {
+ return;
+ }
+
+ fmt::basic_memory_buffer<char, 250> Buffer;
+ fmt::vformat_to(fmt::appender(Buffer), Point.FormatString, Args);
+
+ LogMessage LogMsg(Point, m_Impl->m_Name, std::string_view(Buffer.data(), Buffer.size()));
+ LogMsg.SetThreadId(GetCurrentThreadId());
+ SinkIt(LogMsg);
+ FlushIfNeeded(Point.Level);
+}
+
+void
+Logger::SinkIt(const LogMessage& Msg)
+{
+ for (auto& CurrentSink : m_Impl->m_Sinks)
+ {
+ if (CurrentSink->ShouldLog(Msg.GetLevel()))
+ {
+ try
+ {
+ CurrentSink->Log(Msg);
+ }
+ catch (const std::exception& Ex)
+ {
+ if (m_Impl->m_ErrorHandler)
+ {
+ m_Impl->m_ErrorHandler->HandleError(Ex.what());
+ }
+ }
+ }
+ }
+}
+
+void
+Logger::FlushIfNeeded(LogLevel InLevel)
+{
+ if (InLevel >= m_FlushLevel.load(std::memory_order_relaxed))
+ {
+ Flush();
+ }
+}
+
+void
+Logger::Flush()
+{
+ for (auto& CurrentSink : m_Impl->m_Sinks)
+ {
+ try
+ {
+ CurrentSink->Flush();
+ }
+ catch (const std::exception& Ex)
+ {
+ if (m_Impl->m_ErrorHandler)
+ {
+ m_Impl->m_ErrorHandler->HandleError(Ex.what());
+ }
+ }
+ }
+}
+
+void
+Logger::SetSinks(std::vector<SinkPtr> InSinks)
+{
+ m_Impl->m_Sinks = std::move(InSinks);
+}
+
+void
+Logger::AddSink(SinkPtr InSink)
+{
+ m_Impl->m_Sinks.push_back(std::move(InSink));
+}
+
+void
+Logger::SetErrorHandler(ErrorHandler* Handler)
+{
+ m_Impl->m_ErrorHandler = Handler;
+}
+
+void
+Logger::SetFormatter(std::unique_ptr<Formatter> InFormatter)
+{
+ for (auto& CurrentSink : m_Impl->m_Sinks)
+ {
+ CurrentSink->SetFormatter(InFormatter->Clone());
+ }
+}
+
+std::string_view
+Logger::Name() const
+{
+ return m_Impl->m_Name;
+}
+
+Ref<Logger>
+Logger::Clone(std::string_view NewName) const
+{
+ Ref<Logger> Cloned(new Logger(NewName, m_Impl->m_Sinks));
+ Cloned->SetLevel(m_Level.load(std::memory_order_relaxed));
+ Cloned->SetFlushLevel(m_FlushLevel.load(std::memory_order_relaxed));
+ Cloned->SetErrorHandler(m_Impl->m_ErrorHandler);
+ return Cloned;
+}
+
+} // namespace zen::logging
diff --git a/src/zencore/logging/msvcsink.cpp b/src/zencore/logging/msvcsink.cpp
new file mode 100644
index 000000000..457a4d6e1
--- /dev/null
+++ b/src/zencore/logging/msvcsink.cpp
@@ -0,0 +1,80 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/zencore.h>
+
+#if ZEN_PLATFORM_WINDOWS
+
+# include <zencore/logging/helpers.h>
+# include <zencore/logging/messageonlyformatter.h>
+# include <zencore/logging/msvcsink.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <Windows.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::logging {
+
+// Default formatter for MSVC debug output: [level] message\n
+// For error/critical messages with source info, prepends file(line): so that
+// the message is clickable in the Visual Studio Output window.
+class DefaultMsvcFormatter : public Formatter
+{
+public:
+ void Format(const LogMessage& Msg, MemoryBuffer& Dest) override
+ {
+ const auto& Source = Msg.GetSource();
+ if (Msg.GetLevel() >= LogLevel::Err && Source)
+ {
+ helpers::AppendStringView(Source.Filename, Dest);
+ Dest.push_back('(');
+ helpers::AppendInt(Source.Line, Dest);
+ Dest.push_back(')');
+ Dest.push_back(':');
+ Dest.push_back(' ');
+ }
+
+ Dest.push_back('[');
+ helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest);
+ Dest.push_back(']');
+ Dest.push_back(' ');
+ helpers::AppendStringView(Msg.GetPayload(), Dest);
+ Dest.push_back('\n');
+ }
+
+ std::unique_ptr<Formatter> Clone() const override { return std::make_unique<DefaultMsvcFormatter>(); }
+};
+
+MsvcSink::MsvcSink() : m_Formatter(std::make_unique<DefaultMsvcFormatter>())
+{
+}
+
+void
+MsvcSink::Log(const LogMessage& Msg)
+{
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+
+ MemoryBuffer Formatted;
+ m_Formatter->Format(Msg, Formatted);
+
+ // Null-terminate for OutputDebugStringA
+ Formatted.push_back('\0');
+
+ OutputDebugStringA(Formatted.data());
+}
+
+void
+MsvcSink::Flush()
+{
+ // Nothing to flush for OutputDebugString
+}
+
+void
+MsvcSink::SetFormatter(std::unique_ptr<Formatter> InFormatter)
+{
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ m_Formatter = std::move(InFormatter);
+}
+
+} // namespace zen::logging
+
+#endif // ZEN_PLATFORM_WINDOWS
diff --git a/src/zencore/logging/registry.cpp b/src/zencore/logging/registry.cpp
new file mode 100644
index 000000000..3ed1fb0df
--- /dev/null
+++ b/src/zencore/logging/registry.cpp
@@ -0,0 +1,330 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/logging/registry.h>
+
+#include <zencore/logging/ansicolorsink.h>
+#include <zencore/logging/messageonlyformatter.h>
+
+#include <atomic>
+#include <condition_variable>
+#include <mutex>
+#include <thread>
+#include <unordered_map>
+
+namespace zen::logging {
+
+struct Registry::Impl
+{
+ Impl()
+ {
+ // Create default logger with a stdout color sink
+ SinkPtr DefaultSink(new AnsiColorStdoutSink());
+ m_DefaultLogger = Ref<Logger>(new Logger("", DefaultSink));
+ m_Loggers[""] = m_DefaultLogger;
+ }
+
+ ~Impl() { StopPeriodicFlush(); }
+
+ void Register(Ref<Logger> InLogger)
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ if (m_ErrorHandler)
+ {
+ InLogger->SetErrorHandler(m_ErrorHandler);
+ }
+ m_Loggers[std::string(InLogger->Name())] = std::move(InLogger);
+ }
+
+ void Drop(const std::string& Name)
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ m_Loggers.erase(Name);
+ }
+
+ Ref<Logger> Get(const std::string& Name)
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ auto It = m_Loggers.find(Name);
+ if (It != m_Loggers.end())
+ {
+ return It->second;
+ }
+ return {};
+ }
+
+ void SetDefaultLogger(Ref<Logger> InLogger)
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ if (InLogger)
+ {
+ m_Loggers[std::string(InLogger->Name())] = InLogger;
+ }
+ m_DefaultLogger = std::move(InLogger);
+ }
+
+ Logger* DefaultLoggerRaw() { return m_DefaultLogger.Get(); }
+
+ Ref<Logger> DefaultLogger()
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ return m_DefaultLogger;
+ }
+
+ void SetGlobalLevel(LogLevel Level)
+ {
+ m_GlobalLevel.store(Level, std::memory_order_relaxed);
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ for (auto& [Name, CurLogger] : m_Loggers)
+ {
+ CurLogger->SetLevel(Level);
+ }
+ }
+
+ LogLevel GetGlobalLevel() const { return m_GlobalLevel.load(std::memory_order_relaxed); }
+
+ void SetLevels(Registry::LogLevels Levels, LogLevel* DefaultLevel)
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+
+ if (DefaultLevel)
+ {
+ m_GlobalLevel.store(*DefaultLevel, std::memory_order_relaxed);
+ for (auto& [Name, CurLogger] : m_Loggers)
+ {
+ CurLogger->SetLevel(*DefaultLevel);
+ }
+ }
+
+ for (auto& [LoggerName, Level] : Levels)
+ {
+ auto It = m_Loggers.find(LoggerName);
+ if (It != m_Loggers.end())
+ {
+ It->second->SetLevel(Level);
+ }
+ }
+ }
+
+ void FlushAll()
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ for (auto& [Name, CurLogger] : m_Loggers)
+ {
+ try
+ {
+ CurLogger->Flush();
+ }
+ catch (const std::exception&)
+ {
+ }
+ }
+ }
+
+ void FlushOn(LogLevel Level)
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ m_FlushLevel = Level;
+ for (auto& [Name, CurLogger] : m_Loggers)
+ {
+ CurLogger->SetFlushLevel(Level);
+ }
+ }
+
+ void FlushEvery(std::chrono::seconds Interval)
+ {
+ StopPeriodicFlush();
+
+ m_PeriodicFlushRunning.store(true, std::memory_order_relaxed);
+
+ m_FlushThread = std::thread([this, Interval] {
+ while (m_PeriodicFlushRunning.load(std::memory_order_relaxed))
+ {
+ {
+ std::unique_lock<std::mutex> Lock(m_PeriodicFlushMutex);
+ m_PeriodicFlushCv.wait_for(Lock, Interval, [this] { return !m_PeriodicFlushRunning.load(std::memory_order_relaxed); });
+ }
+
+ if (m_PeriodicFlushRunning.load(std::memory_order_relaxed))
+ {
+ FlushAll();
+ }
+ }
+ });
+ }
+
+ void SetFormatter(std::unique_ptr<Formatter> InFormatter)
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ for (auto& [Name, CurLogger] : m_Loggers)
+ {
+ CurLogger->SetFormatter(InFormatter->Clone());
+ }
+ }
+
+ void ApplyAll(void (*Func)(void*, Ref<Logger>), void* Context)
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ for (auto& [Name, CurLogger] : m_Loggers)
+ {
+ Func(Context, CurLogger);
+ }
+ }
+
+ void SetErrorHandler(ErrorHandler* Handler)
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ m_ErrorHandler = Handler;
+ for (auto& [Name, CurLogger] : m_Loggers)
+ {
+ CurLogger->SetErrorHandler(Handler);
+ }
+ }
+
+ void Shutdown()
+ {
+ StopPeriodicFlush();
+ FlushAll();
+
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+ m_Loggers.clear();
+ m_DefaultLogger = nullptr;
+ }
+
+private:
+ void StopPeriodicFlush()
+ {
+ if (m_FlushThread.joinable())
+ {
+ m_PeriodicFlushRunning.store(false, std::memory_order_relaxed);
+ {
+ std::lock_guard<std::mutex> Lock(m_PeriodicFlushMutex);
+ m_PeriodicFlushCv.notify_one();
+ }
+ m_FlushThread.join();
+ }
+ }
+
+ std::mutex m_Mutex;
+ std::unordered_map<std::string, Ref<Logger>> m_Loggers;
+ Ref<Logger> m_DefaultLogger;
+ std::atomic<LogLevel> m_GlobalLevel{Trace};
+ LogLevel m_FlushLevel{Off};
+ ErrorHandler* m_ErrorHandler = nullptr;
+
+ // Periodic flush
+ std::atomic<bool> m_PeriodicFlushRunning{false};
+ std::mutex m_PeriodicFlushMutex;
+ std::condition_variable m_PeriodicFlushCv;
+ std::thread m_FlushThread;
+};
+
+Registry&
+Registry::Instance()
+{
+ static Registry s_Instance;
+ return s_Instance;
+}
+
+Registry::Registry() : m_Impl(std::make_unique<Impl>())
+{
+}
+
+Registry::~Registry() = default;
+
+void
+Registry::Register(Ref<Logger> InLogger)
+{
+ m_Impl->Register(std::move(InLogger));
+}
+
+void
+Registry::Drop(const std::string& Name)
+{
+ m_Impl->Drop(Name);
+}
+
+Ref<Logger>
+Registry::Get(const std::string& Name)
+{
+ return m_Impl->Get(Name);
+}
+
+void
+Registry::SetDefaultLogger(Ref<Logger> InLogger)
+{
+ m_Impl->SetDefaultLogger(std::move(InLogger));
+}
+
+Logger*
+Registry::DefaultLoggerRaw()
+{
+ return m_Impl->DefaultLoggerRaw();
+}
+
+Ref<Logger>
+Registry::DefaultLogger()
+{
+ return m_Impl->DefaultLogger();
+}
+
+void
+Registry::SetGlobalLevel(LogLevel Level)
+{
+ m_Impl->SetGlobalLevel(Level);
+}
+
+LogLevel
+Registry::GetGlobalLevel() const
+{
+ return m_Impl->GetGlobalLevel();
+}
+
+void
+Registry::SetLevels(LogLevels Levels, LogLevel* DefaultLevel)
+{
+ m_Impl->SetLevels(Levels, DefaultLevel);
+}
+
+void
+Registry::FlushAll()
+{
+ m_Impl->FlushAll();
+}
+
+void
+Registry::FlushOn(LogLevel Level)
+{
+ m_Impl->FlushOn(Level);
+}
+
+void
+Registry::FlushEvery(std::chrono::seconds Interval)
+{
+ m_Impl->FlushEvery(Interval);
+}
+
+void
+Registry::SetFormatter(std::unique_ptr<Formatter> InFormatter)
+{
+ m_Impl->SetFormatter(std::move(InFormatter));
+}
+
+void
+Registry::ApplyAllImpl(void (*Func)(void*, Ref<Logger>), void* Context)
+{
+ m_Impl->ApplyAll(Func, Context);
+}
+
+void
+Registry::SetErrorHandler(ErrorHandler* Handler)
+{
+ m_Impl->SetErrorHandler(Handler);
+}
+
+void
+Registry::Shutdown()
+{
+ m_Impl->Shutdown();
+}
+
+} // namespace zen::logging
diff --git a/src/zencore/logging/tracesink.cpp b/src/zencore/logging/tracesink.cpp
new file mode 100644
index 000000000..8a6f4e40c
--- /dev/null
+++ b/src/zencore/logging/tracesink.cpp
@@ -0,0 +1,92 @@
+
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/logbase.h>
+#include <zencore/logging/tracesink.h>
+#include <zencore/string.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+
+#if ZEN_WITH_TRACE
+
+namespace zen::logging {
+
+UE_TRACE_CHANNEL_DEFINE(LogChannel)
+
+UE_TRACE_EVENT_BEGIN(Logging, LogCategory, NoSync | Important)
+ UE_TRACE_EVENT_FIELD(const void*, CategoryPointer)
+ UE_TRACE_EVENT_FIELD(uint8_t, DefaultVerbosity)
+ UE_TRACE_EVENT_FIELD(UE::Trace::AnsiString, Name)
+UE_TRACE_EVENT_END()
+
+UE_TRACE_EVENT_BEGIN(Logging, LogMessageSpec, NoSync | Important)
+ UE_TRACE_EVENT_FIELD(const void*, LogPoint)
+ UE_TRACE_EVENT_FIELD(const void*, CategoryPointer)
+ UE_TRACE_EVENT_FIELD(int32_t, Line)
+ UE_TRACE_EVENT_FIELD(uint8_t, Verbosity)
+ UE_TRACE_EVENT_FIELD(UE::Trace::AnsiString, FileName)
+ UE_TRACE_EVENT_FIELD(UE::Trace::AnsiString, FormatString)
+UE_TRACE_EVENT_END()
+
+UE_TRACE_EVENT_BEGIN(Logging, LogMessage, NoSync)
+ UE_TRACE_EVENT_FIELD(const void*, LogPoint)
+ UE_TRACE_EVENT_FIELD(uint64_t, Cycle)
+ UE_TRACE_EVENT_FIELD(uint8_t[], FormatArgs)
+UE_TRACE_EVENT_END()
+
+void
+TraceLogCategory(const logging::Logger* Category, const char* Name, logging::LogLevel DefaultVerbosity)
+{
+ uint16_t NameLen = uint16_t(strlen(Name));
+ UE_TRACE_LOG(Logging, LogCategory, LogChannel, NameLen * sizeof(ANSICHAR))
+ << LogCategory.CategoryPointer(Category) << LogCategory.DefaultVerbosity(uint8_t(DefaultVerbosity))
+ << LogCategory.Name(Name, NameLen);
+}
+
+void
+TraceLogMessageSpec(const void* LogPoint,
+ const logging::Logger* Category,
+ logging::LogLevel Verbosity,
+ const std::string_view File,
+ int32_t Line,
+ const std::string_view Format)
+{
+ uint16_t FileNameLen = uint16_t(File.size());
+ uint16_t FormatStringLen = uint16_t(Format.size());
+ uint32_t DataSize = (FileNameLen * sizeof(ANSICHAR)) + (FormatStringLen * sizeof(ANSICHAR));
+ UE_TRACE_LOG(Logging, LogMessageSpec, LogChannel, DataSize)
+ << LogMessageSpec.LogPoint(LogPoint) << LogMessageSpec.CategoryPointer(Category) << LogMessageSpec.Line(Line)
+ << LogMessageSpec.Verbosity(uint8_t(Verbosity)) << LogMessageSpec.FileName(File.data(), FileNameLen)
+ << LogMessageSpec.FormatString(Format.data(), FormatStringLen);
+}
+
+void
+TraceLogMessageInternal(const void* LogPoint, int32_t EncodedFormatArgsSize, const uint8_t* EncodedFormatArgs)
+{
+ UE_TRACE_LOG(Logging, LogMessage, LogChannel) << LogMessage.LogPoint(LogPoint) << LogMessage.Cycle(GetHifreqTimerValue())
+ << LogMessage.FormatArgs(EncodedFormatArgs, EncodedFormatArgsSize);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+void
+TraceSink::Log(const LogMessage& Msg)
+{
+ ZEN_UNUSED(Msg);
+}
+
+void
+TraceSink::Flush()
+{
+}
+
+void
+TraceSink::SetFormatter(std::unique_ptr<Formatter> /*InFormatter*/)
+{
+ // This sink doesn't use a formatter since it just forwards the raw format
+ // args to the trace system
+}
+
+} // namespace zen::logging
+
+#endif
diff --git a/src/zencore/md5.cpp b/src/zencore/md5.cpp
index 4ec145697..f8cfee3ac 100644
--- a/src/zencore/md5.cpp
+++ b/src/zencore/md5.cpp
@@ -56,9 +56,9 @@ struct MD5_CTX
unsigned char digest[16]; /* actual digest after MD5Final call */
};
-void MD5Init();
-void MD5Update();
-void MD5Final();
+void MD5Init(MD5_CTX* mdContext);
+void MD5Update(MD5_CTX* mdContext, unsigned char* inBuf, unsigned int inLen);
+void MD5Final(MD5_CTX* mdContext);
/*
**********************************************************************
@@ -342,6 +342,23 @@ Transform(uint32_t* buf, uint32_t* in)
#undef G
#undef H
#undef I
+#undef ROTATE_LEFT
+#undef S11
+#undef S12
+#undef S13
+#undef S14
+#undef S21
+#undef S22
+#undef S23
+#undef S24
+#undef S31
+#undef S32
+#undef S33
+#undef S34
+#undef S41
+#undef S42
+#undef S43
+#undef S44
namespace zen {
@@ -353,28 +370,32 @@ MD5 MD5::Zero; // Initialized to all zeroes
MD5Stream::MD5Stream()
{
+ static_assert(sizeof(MD5_CTX) <= sizeof(m_Context));
Reset();
}
void
MD5Stream::Reset()
{
+ MD5Init(reinterpret_cast<MD5_CTX*>(m_Context));
}
MD5Stream&
MD5Stream::Append(const void* Data, size_t ByteCount)
{
- ZEN_UNUSED(Data);
- ZEN_UNUSED(ByteCount);
-
+ MD5Update(reinterpret_cast<MD5_CTX*>(m_Context), (unsigned char*)Data, (unsigned int)ByteCount);
return *this;
}
MD5
MD5Stream::GetHash()
{
- MD5 md5{};
+ MD5_CTX FinalCtx;
+ memcpy(&FinalCtx, m_Context, sizeof(MD5_CTX));
+ MD5Final(&FinalCtx);
+ MD5 md5{};
+ memcpy(md5.Hash, FinalCtx.digest, 16);
return md5;
}
@@ -391,7 +412,7 @@ MD5::FromHexString(const char* string)
{
MD5 md5;
- ParseHexBytes(string, 40, md5.Hash);
+ ParseHexBytes(string, 2 * sizeof md5.Hash, md5.Hash);
return md5;
}
@@ -411,7 +432,7 @@ MD5::ToHexString(StringBuilderBase& outBuilder) const
char str[41];
ToHexString(str);
- outBuilder.AppendRange(str, &str[40]);
+ outBuilder.AppendRange(str, &str[StringLength]);
return outBuilder;
}
@@ -437,6 +458,8 @@ md5_forcelink()
// return md5text;
// }
+TEST_SUITE_BEGIN("core.md5");
+
TEST_CASE("MD5")
{
using namespace std::literals;
@@ -451,13 +474,15 @@ TEST_CASE("MD5")
MD5::String_t Buffer;
Result.ToHexString(Buffer);
- CHECK(Output.compare(Buffer));
+ CHECK(Output.compare(Buffer) == 0);
MD5 Reresult = MD5::FromHexString(Buffer);
Reresult.ToHexString(Buffer);
- CHECK(Output.compare(Buffer));
+ CHECK(Output.compare(Buffer) == 0);
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/memoryview.cpp b/src/zencore/memoryview.cpp
index 1f6a6996c..1654b1766 100644
--- a/src/zencore/memoryview.cpp
+++ b/src/zencore/memoryview.cpp
@@ -18,6 +18,8 @@ namespace zen {
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("core.memoryview");
+
TEST_CASE("MemoryView")
{
{
@@ -35,6 +37,8 @@ TEST_CASE("MemoryView")
CHECK(MakeMemoryView<float>({1.0f, 1.2f}).GetSize() == 8);
}
+TEST_SUITE_END();
+
void
memory_forcelink()
{
diff --git a/src/zencore/memtrack/callstacktrace.cpp b/src/zencore/memtrack/callstacktrace.cpp
index a5b7fede6..4a7068568 100644
--- a/src/zencore/memtrack/callstacktrace.cpp
+++ b/src/zencore/memtrack/callstacktrace.cpp
@@ -169,13 +169,13 @@ private:
std::atomic_uint64_t Key;
std::atomic_uint32_t Value;
- inline uint64 GetKey() const { return Key.load(std::memory_order_relaxed); }
+ inline uint64 GetKey() const { return Key.load(std::memory_order_acquire); }
inline uint32_t GetValue() const { return Value.load(std::memory_order_relaxed); }
- inline bool IsEmpty() const { return Key.load(std::memory_order_relaxed) == 0; }
+ inline bool IsEmpty() const { return Key.load(std::memory_order_acquire) == 0; }
inline void SetKeyValue(uint64_t InKey, uint32_t InValue)
{
- Value.store(InValue, std::memory_order_release);
- Key.store(InKey, std::memory_order_relaxed);
+ Value.store(InValue, std::memory_order_relaxed);
+ Key.store(InKey, std::memory_order_release);
}
static inline uint32_t KeyHash(uint64_t Key) { return static_cast<uint32_t>(Key); }
static inline void ClearEntries(FEncounteredCallstackSetEntry* Entries, int32_t EntryCount)
diff --git a/src/zencore/memtrack/tagtrace.cpp b/src/zencore/memtrack/tagtrace.cpp
index 70a74365d..fca4a2ec3 100644
--- a/src/zencore/memtrack/tagtrace.cpp
+++ b/src/zencore/memtrack/tagtrace.cpp
@@ -186,7 +186,7 @@ FTagTrace::AnnounceSpecialTags() const
{
auto EmitTag = [](const char16_t* DisplayString, int32_t Tag, int32_t ParentTag) {
const uint32_t DisplayLen = (uint32_t)StringLength(DisplayString);
- UE_TRACE_LOG(Memory, TagSpec, MemAllocChannel, DisplayLen * sizeof(ANSICHAR))
+ UE_TRACE_LOG(Memory, TagSpec, MemAllocChannel, DisplayLen * sizeof(char16_t))
<< TagSpec.Tag(Tag) << TagSpec.Parent(ParentTag) << TagSpec.Display(DisplayString, DisplayLen);
};
diff --git a/src/zencore/mpscqueue.cpp b/src/zencore/mpscqueue.cpp
index 29c76c3ca..bdd22e20c 100644
--- a/src/zencore/mpscqueue.cpp
+++ b/src/zencore/mpscqueue.cpp
@@ -7,7 +7,8 @@
namespace zen {
-#if ZEN_WITH_TESTS && 0
+#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("core.mpscqueue");
TEST_CASE("mpsc")
{
MpscQueue<std::string> Queue;
@@ -15,6 +16,7 @@ TEST_CASE("mpsc")
std::optional<std::string> Value = Queue.Dequeue();
CHECK_EQ(Value, "hello");
}
+TEST_SUITE_END();
#endif
void
@@ -22,4 +24,4 @@ mpscqueue_forcelink()
{
}
-} // namespace zen \ No newline at end of file
+} // namespace zen
diff --git a/src/zencore/parallelwork.cpp b/src/zencore/parallelwork.cpp
index d86d5815f..94696f479 100644
--- a/src/zencore/parallelwork.cpp
+++ b/src/zencore/parallelwork.cpp
@@ -157,6 +157,8 @@ ParallelWork::RethrowErrors()
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("core.parallelwork");
+
TEST_CASE("parallellwork.nowork")
{
std::atomic<bool> AbortFlag;
@@ -255,6 +257,8 @@ TEST_CASE("parallellwork.limitqueue")
Work.Wait();
}
+TEST_SUITE_END();
+
void
parallellwork_forcelink()
{
diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp
index 56849a10d..f657869dc 100644
--- a/src/zencore/process.cpp
+++ b/src/zencore/process.cpp
@@ -9,6 +9,7 @@
#include <zencore/string.h>
#include <zencore/testing.h>
#include <zencore/timer.h>
+#include <zencore/trace.h>
#include <thread>
@@ -490,6 +491,8 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma
LPSECURITY_ATTRIBUTES ProcessAttributes = nullptr;
LPSECURITY_ATTRIBUTES ThreadAttributes = nullptr;
+ const bool AssignToJob = Options.AssignToJob && Options.AssignToJob->IsValid();
+
DWORD CreationFlags = 0;
if (Options.Flags & CreateProcOptions::Flag_NewConsole)
{
@@ -503,6 +506,10 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma
{
CreationFlags |= CREATE_NEW_PROCESS_GROUP;
}
+ if (AssignToJob)
+ {
+ CreationFlags |= CREATE_SUSPENDED;
+ }
const wchar_t* WorkingDir = nullptr;
if (Options.WorkingDirectory != nullptr)
@@ -571,6 +578,15 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma
return nullptr;
}
+ if (AssignToJob)
+ {
+ if (!Options.AssignToJob->AssignProcess(ProcessInfo.hProcess))
+ {
+ ZEN_WARN("Failed to assign newly created process to job object");
+ }
+ ResumeThread(ProcessInfo.hThread);
+ }
+
CloseHandle(ProcessInfo.hThread);
return ProcessInfo.hProcess;
}
@@ -644,6 +660,8 @@ CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view C
};
PROCESS_INFORMATION ProcessInfo = {};
+ const bool AssignToJob = Options.AssignToJob && Options.AssignToJob->IsValid();
+
if (Options.Flags & CreateProcOptions::Flag_NewConsole)
{
CreateProcFlags |= CREATE_NEW_CONSOLE;
@@ -652,6 +670,10 @@ CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view C
{
CreateProcFlags |= CREATE_NO_WINDOW;
}
+ if (AssignToJob)
+ {
+ CreateProcFlags |= CREATE_SUSPENDED;
+ }
ExtendableWideStringBuilder<256> CommandLineZ;
CommandLineZ << CommandLine;
@@ -679,6 +701,15 @@ CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view C
return nullptr;
}
+ if (AssignToJob)
+ {
+ if (!Options.AssignToJob->AssignProcess(ProcessInfo.hProcess))
+ {
+ ZEN_WARN("Failed to assign newly created process to job object");
+ }
+ ResumeThread(ProcessInfo.hThread);
+ }
+
CloseHandle(ProcessInfo.hThread);
return ProcessInfo.hProcess;
}
@@ -715,6 +746,8 @@ CreateProcElevated(const std::filesystem::path& Executable, std::string_view Com
CreateProcResult
CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine, const CreateProcOptions& Options)
{
+ ZEN_TRACE_CPU("CreateProc");
+
#if ZEN_PLATFORM_WINDOWS
if (Options.Flags & CreateProcOptions::Flag_Unelevated)
{
@@ -746,6 +779,17 @@ CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine
ZEN_UNUSED(Result);
}
+ if (!Options.StdoutFile.empty())
+ {
+ int Fd = open(Options.StdoutFile.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
+ if (Fd >= 0)
+ {
+ dup2(Fd, STDOUT_FILENO);
+ dup2(Fd, STDERR_FILENO);
+ close(Fd);
+ }
+ }
+
if (execv(Executable.c_str(), ArgV.data()) < 0)
{
ThrowLastError("Failed to exec() a new process image");
@@ -845,6 +889,65 @@ ProcessMonitor::IsActive() const
//////////////////////////////////////////////////////////////////////////
+#if ZEN_PLATFORM_WINDOWS
+JobObject::JobObject() = default;
+
+JobObject::~JobObject()
+{
+ if (m_JobHandle)
+ {
+ CloseHandle(m_JobHandle);
+ m_JobHandle = nullptr;
+ }
+}
+
+void
+JobObject::Initialize()
+{
+ ZEN_ASSERT(m_JobHandle == nullptr, "JobObject already initialized");
+
+ m_JobHandle = CreateJobObjectW(nullptr, nullptr);
+ if (!m_JobHandle)
+ {
+ ZEN_WARN("Failed to create job object: {}", zen::GetLastError());
+ return;
+ }
+
+ JOBOBJECT_EXTENDED_LIMIT_INFORMATION LimitInfo = {};
+ LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
+
+ if (!SetInformationJobObject(m_JobHandle, JobObjectExtendedLimitInformation, &LimitInfo, sizeof(LimitInfo)))
+ {
+ ZEN_WARN("Failed to set job object limits: {}", zen::GetLastError());
+ CloseHandle(m_JobHandle);
+ m_JobHandle = nullptr;
+ }
+}
+
+bool
+JobObject::AssignProcess(void* ProcessHandle)
+{
+ ZEN_ASSERT(m_JobHandle != nullptr, "JobObject not initialized");
+ ZEN_ASSERT(ProcessHandle != nullptr, "ProcessHandle is null");
+
+ if (!AssignProcessToJobObject(m_JobHandle, ProcessHandle))
+ {
+ ZEN_WARN("Failed to assign process to job object: {}", zen::GetLastError());
+ return false;
+ }
+
+ return true;
+}
+
+bool
+JobObject::IsValid() const
+{
+ return m_JobHandle != nullptr;
+}
+#endif // ZEN_PLATFORM_WINDOWS
+
+//////////////////////////////////////////////////////////////////////////
+
bool
IsProcessRunning(int pid, std::error_code& OutEc)
{
@@ -1001,6 +1104,232 @@ GetProcessExecutablePath(int Pid, std::error_code& OutEc)
#endif // ZEN_PLATFORM_LINUX
}
+std::string
+GetProcessCommandLine(int Pid, std::error_code& OutEc)
+{
+#if ZEN_PLATFORM_WINDOWS
+ HANDLE hProcess = OpenProcess(PROCESS_QUERY_INFORMATION, FALSE, static_cast<DWORD>(Pid));
+ if (!hProcess)
+ {
+ OutEc = MakeErrorCodeFromLastError();
+ return {};
+ }
+ auto _ = MakeGuard([hProcess] { CloseHandle(hProcess); });
+
+ // NtQueryInformationProcess is an undocumented NT API; load it dynamically.
+ // Info class 60 = ProcessCommandLine, available since Windows 8.1.
+ using PFN_NtQIP = LONG(WINAPI*)(HANDLE, UINT, PVOID, ULONG, PULONG);
+ static const PFN_NtQIP s_NtQIP =
+ reinterpret_cast<PFN_NtQIP>(GetProcAddress(GetModuleHandleW(L"ntdll.dll"), "NtQueryInformationProcess"));
+ if (!s_NtQIP)
+ {
+ return {};
+ }
+
+ constexpr UINT ProcessCommandLineClass = 60;
+ constexpr LONG StatusInfoLengthMismatch = static_cast<LONG>(0xC0000004L);
+
+ ULONG ReturnLength = 0;
+ LONG Status = s_NtQIP(hProcess, ProcessCommandLineClass, nullptr, 0, &ReturnLength);
+ if (Status != StatusInfoLengthMismatch || ReturnLength == 0)
+ {
+ return {};
+ }
+
+ std::vector<char> Buf(ReturnLength);
+ Status = s_NtQIP(hProcess, ProcessCommandLineClass, Buf.data(), ReturnLength, &ReturnLength);
+ if (Status < 0)
+ {
+ OutEc = MakeErrorCodeFromLastError();
+ return {};
+ }
+
+ // Output: UNICODE_STRING header immediately followed by the UTF-16 string data.
+ // The UNICODE_STRING.Buffer field points into our Buf.
+ struct LocalUnicodeString
+ {
+ USHORT Length;
+ USHORT MaximumLength;
+ WCHAR* Buffer;
+ };
+ if (ReturnLength < sizeof(LocalUnicodeString))
+ {
+ return {};
+ }
+ const auto* Us = reinterpret_cast<const LocalUnicodeString*>(Buf.data());
+ if (Us->Length == 0 || Us->Buffer == nullptr)
+ {
+ return {};
+ }
+
+ // Skip argv[0]: may be a quoted path ("C:\...\exe.exe") or a bare path
+ const WCHAR* p = Us->Buffer;
+ const WCHAR* End = Us->Buffer + Us->Length / sizeof(WCHAR);
+ if (p < End && *p == L'"')
+ {
+ ++p;
+ while (p < End && *p != L'"')
+ {
+ ++p;
+ }
+ if (p < End)
+ {
+ ++p; // skip closing quote
+ }
+ }
+ else
+ {
+ while (p < End && *p != L' ')
+ {
+ ++p;
+ }
+ }
+ while (p < End && *p == L' ')
+ {
+ ++p;
+ }
+ if (p >= End)
+ {
+ return {};
+ }
+
+ int Utf8Size = WideCharToMultiByte(CP_UTF8, 0, p, static_cast<int>(End - p), nullptr, 0, nullptr, nullptr);
+ if (Utf8Size <= 0)
+ {
+ OutEc = MakeErrorCodeFromLastError();
+ return {};
+ }
+ std::string Result(Utf8Size, '\0');
+ WideCharToMultiByte(CP_UTF8, 0, p, static_cast<int>(End - p), Result.data(), Utf8Size, nullptr, nullptr);
+ return Result;
+
+#elif ZEN_PLATFORM_LINUX
+ std::string CmdlinePath = fmt::format("/proc/{}/cmdline", Pid);
+ FILE* F = fopen(CmdlinePath.c_str(), "rb");
+ if (!F)
+ {
+ OutEc = MakeErrorCodeFromLastError();
+ return {};
+ }
+ auto FGuard = MakeGuard([F] { fclose(F); });
+
+ // /proc/{pid}/cmdline contains null-separated argv entries; read it all
+ std::string Raw;
+ char Chunk[4096];
+ size_t BytesRead;
+ while ((BytesRead = fread(Chunk, 1, sizeof(Chunk), F)) > 0)
+ {
+ Raw.append(Chunk, BytesRead);
+ }
+ if (Raw.empty())
+ {
+ return {};
+ }
+
+ // Skip argv[0] (first null-terminated entry)
+ const char* p = Raw.data();
+ const char* End = Raw.data() + Raw.size();
+ while (p < End && *p != '\0')
+ {
+ ++p;
+ }
+ if (p < End)
+ {
+ ++p; // skip null terminator of argv[0]
+ }
+
+ // Build result: remaining entries joined by spaces (inter-arg nulls → spaces)
+ std::string Result;
+ Result.reserve(static_cast<size_t>(End - p));
+ for (const char* q = p; q < End; ++q)
+ {
+ Result += (*q == '\0') ? ' ' : *q;
+ }
+ while (!Result.empty() && Result.back() == ' ')
+ {
+ Result.pop_back();
+ }
+ return Result;
+
+#elif ZEN_PLATFORM_MAC
+ int Mib[3] = {CTL_KERN, KERN_PROCARGS2, Pid};
+ size_t BufSize = 0;
+ if (sysctl(Mib, 3, nullptr, &BufSize, nullptr, 0) != 0 || BufSize == 0)
+ {
+ OutEc = MakeErrorCodeFromLastError();
+ return {};
+ }
+
+ std::vector<char> Buf(BufSize);
+ if (sysctl(Mib, 3, Buf.data(), &BufSize, nullptr, 0) != 0)
+ {
+ OutEc = MakeErrorCodeFromLastError();
+ return {};
+ }
+
+ // Layout: [int argc][exec_path\0][null padding][argv[0]\0][argv[1]\0]...[envp\0]...
+ if (BufSize < sizeof(int))
+ {
+ return {};
+ }
+ int Argc = 0;
+ memcpy(&Argc, Buf.data(), sizeof(int));
+ if (Argc <= 1)
+ {
+ return {};
+ }
+
+ const char* p = Buf.data() + sizeof(int);
+ const char* End = Buf.data() + BufSize;
+
+ // Skip exec_path and any trailing null padding that follows it
+ while (p < End && *p != '\0')
+ {
+ ++p;
+ }
+ while (p < End && *p == '\0')
+ {
+ ++p;
+ }
+
+ // Skip argv[0]
+ while (p < End && *p != '\0')
+ {
+ ++p;
+ }
+ if (p < End)
+ {
+ ++p;
+ }
+
+ // Collect argv[1..Argc-1]
+ std::string Result;
+ for (int i = 1; i < Argc && p < End; ++i)
+ {
+ if (i > 1)
+ {
+ Result += ' ';
+ }
+ const char* ArgStart = p;
+ while (p < End && *p != '\0')
+ {
+ ++p;
+ }
+ Result.append(ArgStart, p);
+ if (p < End)
+ {
+ ++p;
+ }
+ }
+ return Result;
+
+#else
+ ZEN_UNUSED(Pid);
+ ZEN_UNUSED(OutEc);
+ return {};
+#endif
+}
+
std::error_code
FindProcess(const std::filesystem::path& ExecutableImage, ProcessHandle& OutHandle, bool IncludeSelf)
{
diff --git a/src/zencore/refcount.cpp b/src/zencore/refcount.cpp
index a6a86ee12..f19afe715 100644
--- a/src/zencore/refcount.cpp
+++ b/src/zencore/refcount.cpp
@@ -33,6 +33,8 @@ refcount_forcelink()
{
}
+TEST_SUITE_BEGIN("core.refcount");
+
TEST_CASE("RefPtr")
{
RefPtr<TestRefClass> Ref;
@@ -60,6 +62,8 @@ TEST_CASE("RefPtr")
CHECK(IsDestroyed == true);
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp
index 00e67dc85..8d087e8c6 100644
--- a/src/zencore/sentryintegration.cpp
+++ b/src/zencore/sentryintegration.cpp
@@ -4,29 +4,23 @@
#include <zencore/config.h>
#include <zencore/logging.h>
+#include <zencore/logging/registry.h>
+#include <zencore/logging/sink.h>
#include <zencore/session.h>
#include <zencore/uid.h>
#include <stdarg.h>
#include <stdio.h>
-#if ZEN_PLATFORM_LINUX
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
# include <pwd.h>
+# include <unistd.h>
#endif
-#if ZEN_PLATFORM_MAC
-# include <pwd.h>
-#endif
-
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <spdlog/spdlog.h>
-ZEN_THIRD_PARTY_INCLUDES_END
-
#if ZEN_USE_SENTRY
# define SENTRY_BUILD_STATIC 1
ZEN_THIRD_PARTY_INCLUDES_START
# include <sentry.h>
-# include <spdlog/sinks/base_sink.h>
ZEN_THIRD_PARTY_INCLUDES_END
namespace sentry {
@@ -44,71 +38,58 @@ struct SentryAssertImpl : zen::AssertImpl
const zen::CallstackFrames* Callstack) override;
};
-class sentry_sink final : public spdlog::sinks::base_sink<spdlog::details::null_mutex>
+static constexpr sentry_level_t MapToSentryLevel[zen::logging::LogLevelCount] = {SENTRY_LEVEL_DEBUG,
+ SENTRY_LEVEL_DEBUG,
+ SENTRY_LEVEL_INFO,
+ SENTRY_LEVEL_WARNING,
+ SENTRY_LEVEL_ERROR,
+ SENTRY_LEVEL_FATAL,
+ SENTRY_LEVEL_DEBUG};
+
+class SentrySink final : public zen::logging::Sink
{
public:
- sentry_sink();
- ~sentry_sink();
+ SentrySink() = default;
+ ~SentrySink() = default;
-protected:
- void sink_it_(const spdlog::details::log_msg& msg) override;
- void flush_() override;
+ void Log(const zen::logging::LogMessage& Msg) override
+ {
+ if (Msg.GetLevel() != zen::logging::Err && Msg.GetLevel() != zen::logging::Critical)
+ {
+ return;
+ }
+ try
+ {
+ std::string Message = fmt::format("{}\n{}({})", Msg.GetPayload(), Msg.GetSource().Filename, Msg.GetSource().Line);
+ sentry_value_t Event = sentry_value_new_message_event(
+ /* level */ MapToSentryLevel[Msg.GetLevel()],
+ /* logger */ nullptr,
+ /* message */ Message.c_str());
+ sentry_event_value_add_stacktrace(Event, NULL, 0);
+ sentry_capture_event(Event);
+ }
+ catch (const std::exception&)
+ {
+ // If our logging with Message formatting fails we do a non-allocating version and just post the payload raw
+ char TmpBuffer[256];
+ size_t MaxCopy = zen::Min<size_t>(Msg.GetPayload().size(), size_t(255));
+ memcpy(TmpBuffer, Msg.GetPayload().data(), MaxCopy);
+ TmpBuffer[MaxCopy] = '\0';
+ sentry_value_t Event = sentry_value_new_message_event(
+ /* level */ SENTRY_LEVEL_ERROR,
+ /* logger */ nullptr,
+ /* message */ TmpBuffer);
+ sentry_event_value_add_stacktrace(Event, NULL, 0);
+ sentry_capture_event(Event);
+ }
+ }
+
+ void Flush() override {}
+ void SetFormatter(std::unique_ptr<zen::logging::Formatter>) override {}
};
//////////////////////////////////////////////////////////////////////////
-static constexpr sentry_level_t MapToSentryLevel[spdlog::level::level_enum::n_levels] = {SENTRY_LEVEL_DEBUG,
- SENTRY_LEVEL_DEBUG,
- SENTRY_LEVEL_INFO,
- SENTRY_LEVEL_WARNING,
- SENTRY_LEVEL_ERROR,
- SENTRY_LEVEL_FATAL,
- SENTRY_LEVEL_DEBUG};
-
-sentry_sink::sentry_sink()
-{
-}
-sentry_sink::~sentry_sink()
-{
-}
-
-void
-sentry_sink::sink_it_(const spdlog::details::log_msg& msg)
-{
- if (msg.level != spdlog::level::err && msg.level != spdlog::level::critical)
- {
- return;
- }
- try
- {
- std::string Message = fmt::format("{}\n{}({}) [{}]", msg.payload, msg.source.filename, msg.source.line, msg.source.funcname);
- sentry_value_t event = sentry_value_new_message_event(
- /* level */ MapToSentryLevel[msg.level],
- /* logger */ nullptr,
- /* message */ Message.c_str());
- sentry_event_value_add_stacktrace(event, NULL, 0);
- sentry_capture_event(event);
- }
- catch (const std::exception&)
- {
- // If our logging with Message formatting fails we do a non-allocating version and just post the msg.payload raw
- char TmpBuffer[256];
- size_t MaxCopy = zen::Min<size_t>(msg.payload.size(), size_t(255));
- memcpy(TmpBuffer, msg.payload.data(), MaxCopy);
- TmpBuffer[MaxCopy] = '\0';
- sentry_value_t event = sentry_value_new_message_event(
- /* level */ SENTRY_LEVEL_ERROR,
- /* logger */ nullptr,
- /* message */ TmpBuffer);
- sentry_event_value_add_stacktrace(event, NULL, 0);
- sentry_capture_event(event);
- }
-}
-void
-sentry_sink::flush_()
-{
-}
-
void
SentryAssertImpl::OnAssert(const char* Filename,
int LineNumber,
@@ -145,6 +126,10 @@ SentryAssertImpl::OnAssert(const char* Filename,
namespace zen {
# if ZEN_USE_SENTRY
+ZEN_DEFINE_LOG_CATEGORY_STATIC(LogSentry, "sentry-sdk");
+
+static std::atomic<bool> s_SentryLogEnabled{true};
+
static void
SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[maybe_unused]] void* Userdata)
{
@@ -163,26 +148,62 @@ SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[may
MessagePtr = LogMessage.c_str();
}
+ // SentryLogFunction can be called before the logging system is initialized
+ // (during sentry_init which runs before InitializeLogging), or after it has
+ // been shut down (during sentry_close on a background worker thread). Fall
+ // back to console logging when the category logger is not available.
+ //
+ // Since we want to default to WARN level but this runs before logging has
+ // been configured, we ignore the callbacks for DEBUG/INFO explicitly here
+ // which means users don't see every possible log message if they're trying
+ // to configure the levels using --log-debug=sentry-sdk
+ if (!TheDefaultLogger || !s_SentryLogEnabled.load(std::memory_order_acquire))
+ {
+ switch (Level)
+ {
+ case SENTRY_LEVEL_DEBUG:
+ // ZEN_CONSOLE_DEBUG("sentry: {}", MessagePtr);
+ break;
+
+ case SENTRY_LEVEL_INFO:
+ // ZEN_CONSOLE_INFO("sentry: {}", MessagePtr);
+ break;
+
+ case SENTRY_LEVEL_WARNING:
+ ZEN_CONSOLE_WARN("sentry: {}", MessagePtr);
+ break;
+
+ case SENTRY_LEVEL_ERROR:
+ ZEN_CONSOLE_ERROR("sentry: {}", MessagePtr);
+ break;
+
+ case SENTRY_LEVEL_FATAL:
+ ZEN_CONSOLE_CRITICAL("sentry: {}", MessagePtr);
+ break;
+ }
+ return;
+ }
+
switch (Level)
{
case SENTRY_LEVEL_DEBUG:
- ZEN_CONSOLE_DEBUG("sentry: {}", MessagePtr);
+ ZEN_LOG_DEBUG(LogSentry, "sentry: {}", MessagePtr);
break;
case SENTRY_LEVEL_INFO:
- ZEN_CONSOLE_INFO("sentry: {}", MessagePtr);
+ ZEN_LOG_INFO(LogSentry, "sentry: {}", MessagePtr);
break;
case SENTRY_LEVEL_WARNING:
- ZEN_CONSOLE_WARN("sentry: {}", MessagePtr);
+ ZEN_LOG_WARN(LogSentry, "sentry: {}", MessagePtr);
break;
case SENTRY_LEVEL_ERROR:
- ZEN_CONSOLE_ERROR("sentry: {}", MessagePtr);
+ ZEN_LOG_ERROR(LogSentry, "sentry: {}", MessagePtr);
break;
case SENTRY_LEVEL_FATAL:
- ZEN_CONSOLE_CRITICAL("sentry: {}", MessagePtr);
+ ZEN_LOG_CRITICAL(LogSentry, "sentry: {}", MessagePtr);
break;
}
}
@@ -194,11 +215,21 @@ SentryIntegration::SentryIntegration()
SentryIntegration::~SentryIntegration()
{
+ Close();
+}
+
+void
+SentryIntegration::Close()
+{
if (m_IsInitialized && m_SentryErrorCode == 0)
{
logging::SetErrorLog("");
m_SentryAssert.reset();
+ // Disable spdlog forwarding before sentry_close() since its background
+ // worker thread may still log during shutdown via SentryLogFunction
+ s_SentryLogEnabled.store(false, std::memory_order_release);
sentry_close();
+ m_IsInitialized = false;
}
}
@@ -298,7 +329,9 @@ SentryIntegration::Initialize(const Config& Conf, const std::string& CommandLine
sentry_set_user(SentryUserObject);
- m_SentryLogger = spdlog::create<sentry::sentry_sink>("sentry");
+ logging::SinkPtr SentrySink(new sentry::SentrySink());
+ m_SentryLogger = Ref<logging::Logger>(new logging::Logger("sentry", std::vector<logging::SinkPtr>{SentrySink}));
+ logging::Registry::Instance().Register(m_SentryLogger);
logging::SetErrorLog("sentry");
m_SentryAssert = std::make_unique<sentry::SentryAssertImpl>();
@@ -310,22 +343,31 @@ SentryIntegration::Initialize(const Config& Conf, const std::string& CommandLine
void
SentryIntegration::LogStartupInformation()
{
+ // Initialize the sentry-sdk log category at Warn level to reduce startup noise.
+ // The level can be overridden via --log-debug=sentry-sdk or --log-info=sentry-sdk
+ LogSentry.Logger().SetLogLevel(logging::Warn);
+
if (m_IsInitialized)
{
if (m_SentryErrorCode == 0)
{
if (m_AllowPII)
{
- ZEN_INFO("sentry initialized, username: '{}', hostname: '{}', id: '{}'", m_SentryUserName, m_SentryHostName, m_SentryId);
+ ZEN_LOG_INFO(LogSentry,
+ "sentry initialized, username: '{}', hostname: '{}', id: '{}'",
+ m_SentryUserName,
+ m_SentryHostName,
+ m_SentryId);
}
else
{
- ZEN_INFO("sentry initialized with anonymous reports");
+ ZEN_LOG_INFO(LogSentry, "sentry initialized with anonymous reports");
}
}
else
{
- ZEN_WARN(
+ ZEN_LOG_WARN(
+ LogSentry,
"sentry_init returned failure! (error code: {}) note that sentry expects crashpad_handler to exist alongside the running "
"executable",
m_SentryErrorCode);
diff --git a/src/zencore/sha1.cpp b/src/zencore/sha1.cpp
index 3ee74d7d8..807ae4c30 100644
--- a/src/zencore/sha1.cpp
+++ b/src/zencore/sha1.cpp
@@ -373,6 +373,8 @@ sha1_forcelink()
// return sha1text;
// }
+TEST_SUITE_BEGIN("core.sha1");
+
TEST_CASE("SHA1")
{
uint8_t sha1_empty[20] = {0xda, 0x39, 0xa3, 0xee, 0x5e, 0x6b, 0x4b, 0x0d, 0x32, 0x55,
@@ -438,6 +440,8 @@ TEST_CASE("SHA1")
}
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/sharedbuffer.cpp b/src/zencore/sharedbuffer.cpp
index 78efb9d42..8dc6d49d8 100644
--- a/src/zencore/sharedbuffer.cpp
+++ b/src/zencore/sharedbuffer.cpp
@@ -152,10 +152,14 @@ sharedbuffer_forcelink()
{
}
+TEST_SUITE_BEGIN("core.sharedbuffer");
+
TEST_CASE("SharedBuffer")
{
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/stream.cpp b/src/zencore/stream.cpp
index a800ce121..de67303a4 100644
--- a/src/zencore/stream.cpp
+++ b/src/zencore/stream.cpp
@@ -79,6 +79,8 @@ BufferReader::Serialize(void* V, int64_t Length)
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("core.stream");
+
TEST_CASE("binary.writer.span")
{
BinaryWriter Writer;
@@ -91,6 +93,8 @@ TEST_CASE("binary.writer.span")
CHECK(memcmp(Result.GetData(), "apa banan", 9) == 0);
}
+TEST_SUITE_END();
+
void
stream_forcelink()
{
diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp
index 0ee863b74..ed0ba6f46 100644
--- a/src/zencore/string.cpp
+++ b/src/zencore/string.cpp
@@ -4,6 +4,7 @@
#include <zencore/memoryview.h>
#include <zencore/string.h>
#include <zencore/testing.h>
+#include <zencore/testutils.h>
#include <inttypes.h>
#include <math.h>
@@ -24,6 +25,10 @@ utf16to8_impl(u16bit_iterator StartIt, u16bit_iterator EndIt, ::zen::StringBuild
// Take care of surrogate pairs first
if (utf8::internal::is_lead_surrogate(cp))
{
+ if (StartIt == EndIt)
+ {
+ break;
+ }
uint32_t trail_surrogate = utf8::internal::mask16(*StartIt++);
cp = (cp << 10) + trail_surrogate + utf8::internal::SURROGATE_OFFSET;
}
@@ -180,7 +185,21 @@ Utf8ToWide(const std::u8string_view& Str8, WideStringBuilderBase& OutString)
if (!ByteCount)
{
+#if ZEN_SIZEOF_WCHAR_T == 2
+ if (CurrentOutChar > 0xFFFF)
+ {
+ // Supplementary plane: emit a UTF-16 surrogate pair
+ uint32_t Adjusted = uint32_t(CurrentOutChar - 0x10000);
+ OutString.Append(wchar_t(0xD800 + (Adjusted >> 10)));
+ OutString.Append(wchar_t(0xDC00 + (Adjusted & 0x3FF)));
+ }
+ else
+ {
+ OutString.Append(wchar_t(CurrentOutChar));
+ }
+#else
OutString.Append(wchar_t(CurrentOutChar));
+#endif
CurrentOutChar = 0;
}
}
@@ -249,6 +268,17 @@ namespace {
/* kNicenumTime */ 1000};
} // namespace
+uint64_t
+IntPow(uint64_t Base, int Exp)
+{
+ uint64_t Result = 1;
+ for (int I = 0; I < Exp; ++I)
+ {
+ Result *= Base;
+ }
+ return Result;
+}
+
/*
* Convert a number to an appropriately human-readable output.
*/
@@ -296,7 +326,7 @@ NiceNumGeneral(uint64_t Num, std::span<char> Buffer, NicenumFormat Format)
const char* u = UnitStrings[Format][Index];
- if ((Index == 0) || ((Num % (uint64_t)powl((int)KiloUnit[Format], Index)) == 0))
+ if ((Index == 0) || ((Num % IntPow(KiloUnit[Format], Index)) == 0))
{
/*
* If this is an even multiple of the base, always display
@@ -320,7 +350,7 @@ NiceNumGeneral(uint64_t Num, std::span<char> Buffer, NicenumFormat Format)
for (int i = 2; i >= 0; i--)
{
- double Value = (double)Num / (uint64_t)powl((int)KiloUnit[Format], Index);
+ double Value = (double)Num / IntPow(KiloUnit[Format], Index);
/*
* Don't print floating point values for time. Note,
@@ -520,13 +550,38 @@ UrlDecode(std::string_view InUrl)
return std::string(Url.ToView());
}
-//////////////////////////////////////////////////////////////////////////
-//
-// Unit tests
-//
+std::string
+HideSensitiveString(std::string_view String)
+{
+ const size_t Length = String.length();
+ const size_t SourceLength = Length > 16 ? 4 : 0;
+ const size_t PadLength = Min(Length - SourceLength, 4u);
+ const bool AddEllipsis = (SourceLength + PadLength) < Length;
+ StringBuilder<16> SB;
+ if (SourceLength > 0)
+ {
+ SB << String.substr(0, SourceLength);
+ }
+ if (PadLength > 0)
+ {
+ SB << std::string(PadLength, 'X');
+ }
+ if (AddEllipsis)
+ {
+ SB << "...";
+ }
+ return SB.ToString();
+};
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Unit tests
+ //
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("core.string");
+
TEST_CASE("url")
{
using namespace std::literals;
@@ -793,11 +848,6 @@ TEST_CASE("niceNum")
}
}
-void
-string_forcelink()
-{
-}
-
TEST_CASE("StringBuilder")
{
StringBuilder<64> sb;
@@ -963,33 +1013,131 @@ TEST_CASE("ExtendableWideStringBuilder")
TEST_CASE("utf8")
{
+ using namespace utf8test;
+
SUBCASE("utf8towide")
{
- // TODO: add more extensive testing here - this covers a very small space
-
WideStringBuilder<32> wout;
Utf8ToWide(u8"abcdefghi", wout);
CHECK(StringEquals(L"abcdefghi", wout.c_str()));
wout.Reset();
+ Utf8ToWide(u8"abc\xC3\xA4\xC3\xB6\xC3\xBC", wout);
+ CHECK(StringEquals(L"abc\u00E4\u00F6\u00FC", wout.c_str()));
+
+ wout.Reset();
+ Utf8ToWide(std::string_view(kLatin), wout);
+ CHECK(StringEquals(kLatinW, wout.c_str()));
+
+ wout.Reset();
+ Utf8ToWide(std::string_view(kCyrillic), wout);
+ CHECK(StringEquals(kCyrillicW, wout.c_str()));
+
+ wout.Reset();
+ Utf8ToWide(std::string_view(kCJK), wout);
+ CHECK(StringEquals(kCJKW, wout.c_str()));
+
+ wout.Reset();
+ Utf8ToWide(std::string_view(kMixed), wout);
+ CHECK(StringEquals(kMixedW, wout.c_str()));
- Utf8ToWide(u8"abc���", wout);
- CHECK(StringEquals(L"abc���", wout.c_str()));
+ wout.Reset();
+ Utf8ToWide(std::string_view(kEmoji), wout);
+ CHECK(StringEquals(kEmojiW, wout.c_str()));
}
SUBCASE("widetoutf8")
{
- // TODO: add more extensive testing here - this covers a very small space
-
- StringBuilder<32> out;
+ StringBuilder<64> out;
WideToUtf8(L"abcdefghi", out);
CHECK(StringEquals("abcdefghi", out.c_str()));
out.Reset();
+ WideToUtf8(kLatinW, out);
+ CHECK(StringEquals(kLatin, out.c_str()));
+
+ out.Reset();
+ WideToUtf8(kCyrillicW, out);
+ CHECK(StringEquals(kCyrillic, out.c_str()));
+
+ out.Reset();
+ WideToUtf8(kCJKW, out);
+ CHECK(StringEquals(kCJK, out.c_str()));
+
+ out.Reset();
+ WideToUtf8(kMixedW, out);
+ CHECK(StringEquals(kMixed, out.c_str()));
- WideToUtf8(L"abc���", out);
- CHECK(StringEquals(u8"abc���", out.c_str()));
+ out.Reset();
+ WideToUtf8(kEmojiW, out);
+ CHECK(StringEquals(kEmoji, out.c_str()));
+ }
+
+ SUBCASE("roundtrip")
+ {
+ // UTF-8 -> Wide -> UTF-8 identity
+ const char* Utf8Strings[] = {kLatin, kCyrillic, kCJK, kMixed, kEmoji};
+ for (const char* Utf8Str : Utf8Strings)
+ {
+ ExtendableWideStringBuilder<64> Wide;
+ Utf8ToWide(std::string_view(Utf8Str), Wide);
+
+ ExtendableStringBuilder<64> Back;
+ WideToUtf8(std::wstring_view(Wide.c_str()), Back);
+ CHECK(StringEquals(Utf8Str, Back.c_str()));
+ }
+
+ // Wide -> UTF-8 -> Wide identity
+ const wchar_t* WideStrings[] = {kLatinW, kCyrillicW, kCJKW, kMixedW, kEmojiW};
+ for (const wchar_t* WideStr : WideStrings)
+ {
+ ExtendableStringBuilder<64> Utf8;
+ WideToUtf8(std::wstring_view(WideStr), Utf8);
+
+ ExtendableWideStringBuilder<64> Back;
+ Utf8ToWide(std::string_view(Utf8.c_str()), Back);
+ CHECK(StringEquals(WideStr, Back.c_str()));
+ }
+
+ // Empty string round-trip
+ {
+ ExtendableWideStringBuilder<8> Wide;
+ Utf8ToWide(std::string_view(""), Wide);
+ CHECK(Wide.Size() == 0);
+
+ ExtendableStringBuilder<8> Narrow;
+ WideToUtf8(std::wstring_view(L""), Narrow);
+ CHECK(Narrow.Size() == 0);
+ }
+ }
+
+ SUBCASE("IsValidUtf8")
+ {
+ // Valid inputs
+ CHECK(IsValidUtf8(""));
+ CHECK(IsValidUtf8("hello world"));
+ CHECK(IsValidUtf8(kLatin));
+ CHECK(IsValidUtf8(kCyrillic));
+ CHECK(IsValidUtf8(kCJK));
+ CHECK(IsValidUtf8(kMixed));
+ CHECK(IsValidUtf8(kEmoji));
+
+ // Invalid: truncated 2-byte sequence
+ CHECK(!IsValidUtf8(std::string_view("\xC3", 1)));
+
+ // Invalid: truncated 3-byte sequence
+ CHECK(!IsValidUtf8(std::string_view("\xE6\x97", 2)));
+
+ // Invalid: truncated 4-byte sequence
+ CHECK(!IsValidUtf8(std::string_view("\xF0\x9F\x93", 3)));
+
+ // Invalid: bad start byte
+ CHECK(!IsValidUtf8(std::string_view("\xFF", 1)));
+ CHECK(!IsValidUtf8(std::string_view("\xFE", 1)));
+
+ // Invalid: overlong encoding of '/' (U+002F)
+ CHECK(!IsValidUtf8(std::string_view("\xC0\xAF", 2)));
}
}
@@ -1105,6 +1253,28 @@ TEST_CASE("string")
}
}
+TEST_CASE("hidesensitivestring")
+{
+ using namespace std::literals;
+
+ CHECK_EQ(HideSensitiveString(""sv), ""sv);
+ CHECK_EQ(HideSensitiveString("A"sv), "X"sv);
+ CHECK_EQ(HideSensitiveString("ABCD"sv), "XXXX"sv);
+ CHECK_EQ(HideSensitiveString("ABCDE"sv), "XXXX..."sv);
+ CHECK_EQ(HideSensitiveString("ABCDEFGH"sv), "XXXX..."sv);
+ CHECK_EQ(HideSensitiveString("ABCDEFGHIJKLMNOP"sv), "XXXX..."sv);
+ CHECK_EQ(HideSensitiveString("ABCDEFGHIJKLMNOPQ"sv), "ABCDXXXX..."sv);
+ CHECK_EQ(HideSensitiveString("ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"sv), "ABCDXXXX..."sv);
+ CHECK_EQ(HideSensitiveString("1234567890123456789"sv), "1234XXXX..."sv);
+}
+
+TEST_SUITE_END();
+
+void
+string_forcelink()
+{
+}
+
#endif
} // namespace zen
diff --git a/src/zencore/system.cpp b/src/zencore/system.cpp
index b9ac3bdee..141450b84 100644
--- a/src/zencore/system.cpp
+++ b/src/zencore/system.cpp
@@ -4,15 +4,20 @@
#include <zencore/compactbinarybuilder.h>
#include <zencore/except.h>
+#include <zencore/fmtutils.h>
#include <zencore/memory/memory.h>
#include <zencore/string.h>
+#include <mutex>
+
#if ZEN_PLATFORM_WINDOWS
# include <zencore/windows.h>
ZEN_THIRD_PARTY_INCLUDES_START
# include <iphlpapi.h>
# include <winsock2.h>
+# include <pdh.h>
+# pragma comment(lib, "pdh.lib")
ZEN_THIRD_PARTY_INCLUDES_END
#elif ZEN_PLATFORM_LINUX
# include <sys/utsname.h>
@@ -65,55 +70,73 @@ GetSystemMetrics()
// Determine physical core count
- DWORD BufferSize = 0;
- BOOL Result = GetLogicalProcessorInformation(nullptr, &BufferSize);
- if (int32_t Error = GetLastError(); Error != ERROR_INSUFFICIENT_BUFFER)
{
- ThrowSystemError(Error, "Failed to get buffer size for logical processor information");
- }
+ DWORD BufferSize = 0;
+ BOOL Result = GetLogicalProcessorInformationEx(RelationAll, nullptr, &BufferSize);
+ if (int32_t Error = GetLastError(); Error != ERROR_INSUFFICIENT_BUFFER)
+ {
+ ThrowSystemError(Error, "Failed to get buffer size for logical processor information");
+ }
- PSYSTEM_LOGICAL_PROCESSOR_INFORMATION Buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION)Memory::Alloc(BufferSize);
+ PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX Buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)Memory::Alloc(BufferSize);
- Result = GetLogicalProcessorInformation(Buffer, &BufferSize);
- if (!Result)
- {
- Memory::Free(Buffer);
- throw std::runtime_error("Failed to get logical processor information");
- }
-
- DWORD ProcessorPkgCount = 0;
- DWORD ProcessorCoreCount = 0;
- DWORD ByteOffset = 0;
- while (ByteOffset + sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION) <= BufferSize)
- {
- const SYSTEM_LOGICAL_PROCESSOR_INFORMATION& Slpi = Buffer[ByteOffset / sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION)];
- if (Slpi.Relationship == RelationProcessorCore)
+ Result = GetLogicalProcessorInformationEx(RelationAll, Buffer, &BufferSize);
+ if (!Result)
{
- ProcessorCoreCount++;
+ Memory::Free(Buffer);
+ throw std::runtime_error("Failed to get logical processor information");
}
- else if (Slpi.Relationship == RelationProcessorPackage)
+
+ DWORD ProcessorPkgCount = 0;
+ DWORD ProcessorCoreCount = 0;
+ DWORD LogicalProcessorCount = 0;
+
+ BYTE* Ptr = reinterpret_cast<BYTE*>(Buffer);
+ BYTE* const End = Ptr + BufferSize;
+ while (Ptr < End)
{
- ProcessorPkgCount++;
+ const SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX& Slpi = *reinterpret_cast<const SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*>(Ptr);
+ if (Slpi.Relationship == RelationProcessorCore)
+ {
+ ++ProcessorCoreCount;
+
+ // Count logical processors (threads) across all processor groups for this core.
+ // Each core entry lists one GROUP_AFFINITY per group it spans; each set bit
+ // in the Mask represents one logical processor (HyperThreading sibling).
+ for (WORD g = 0; g < Slpi.Processor.GroupCount; ++g)
+ {
+ LogicalProcessorCount += static_cast<DWORD>(__popcnt64(Slpi.Processor.GroupMask[g].Mask));
+ }
+ }
+ else if (Slpi.Relationship == RelationProcessorPackage)
+ {
+ ++ProcessorPkgCount;
+ }
+ Ptr += Slpi.Size;
}
- ByteOffset += sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION);
- }
- Metrics.CoreCount = ProcessorCoreCount;
- Metrics.CpuCount = ProcessorPkgCount;
+ Metrics.CoreCount = ProcessorCoreCount;
+ Metrics.CpuCount = ProcessorPkgCount;
+ Metrics.LogicalProcessorCount = LogicalProcessorCount;
- Memory::Free(Buffer);
+ Memory::Free(Buffer);
+ }
// Query memory status
- MEMORYSTATUSEX MemStatus{.dwLength = sizeof(MEMORYSTATUSEX)};
- GlobalMemoryStatusEx(&MemStatus);
+ {
+ MEMORYSTATUSEX MemStatus{.dwLength = sizeof(MEMORYSTATUSEX)};
+ GlobalMemoryStatusEx(&MemStatus);
+
+ Metrics.SystemMemoryMiB = MemStatus.ullTotalPhys / 1024 / 1024;
+ Metrics.AvailSystemMemoryMiB = MemStatus.ullAvailPhys / 1024 / 1024;
+ Metrics.VirtualMemoryMiB = MemStatus.ullTotalVirtual / 1024 / 1024;
+ Metrics.AvailVirtualMemoryMiB = MemStatus.ullAvailVirtual / 1024 / 1024;
+ Metrics.PageFileMiB = MemStatus.ullTotalPageFile / 1024 / 1024;
+ Metrics.AvailPageFileMiB = MemStatus.ullAvailPageFile / 1024 / 1024;
+ }
- Metrics.SystemMemoryMiB = MemStatus.ullTotalPhys / 1024 / 1024;
- Metrics.AvailSystemMemoryMiB = MemStatus.ullAvailPhys / 1024 / 1024;
- Metrics.VirtualMemoryMiB = MemStatus.ullTotalVirtual / 1024 / 1024;
- Metrics.AvailVirtualMemoryMiB = MemStatus.ullAvailVirtual / 1024 / 1024;
- Metrics.PageFileMiB = MemStatus.ullTotalPageFile / 1024 / 1024;
- Metrics.AvailPageFileMiB = MemStatus.ullAvailPageFile / 1024 / 1024;
+ Metrics.UptimeSeconds = GetTickCount64() / 1000;
return Metrics;
}
@@ -206,6 +229,17 @@ GetSystemMetrics()
Metrics.VirtualMemoryMiB = Metrics.SystemMemoryMiB;
Metrics.AvailVirtualMemoryMiB = Metrics.AvailSystemMemoryMiB;
+ // System uptime
+ if (FILE* UptimeFile = fopen("/proc/uptime", "r"))
+ {
+ double UptimeSec = 0;
+ if (fscanf(UptimeFile, "%lf", &UptimeSec) == 1)
+ {
+ Metrics.UptimeSeconds = static_cast<uint64_t>(UptimeSec);
+ }
+ fclose(UptimeFile);
+ }
+
// Parse /proc/meminfo for swap/page file information
Metrics.PageFileMiB = 0;
Metrics.AvailPageFileMiB = 0;
@@ -298,12 +332,35 @@ GetSystemMetrics()
Metrics.PageFileMiB = SwapUsage.xsu_total / 1024 / 1024;
Metrics.AvailPageFileMiB = (SwapUsage.xsu_total - SwapUsage.xsu_used) / 1024 / 1024;
+ // System uptime via boot time
+ {
+ struct timeval BootTime
+ {
+ };
+ Size = sizeof(BootTime);
+ if (sysctlbyname("kern.boottime", &BootTime, &Size, nullptr, 0) == 0)
+ {
+ Metrics.UptimeSeconds = static_cast<uint64_t>(time(nullptr) - BootTime.tv_sec);
+ }
+ }
+
return Metrics;
}
#else
# error "Unknown platform"
#endif
+ExtendedSystemMetrics
+ApplyReportingOverrides(ExtendedSystemMetrics Metrics)
+{
+ if (g_FakeCpuCount)
+ {
+ Metrics.CoreCount = g_FakeCpuCount;
+ Metrics.LogicalProcessorCount = g_FakeCpuCount;
+ }
+ return Metrics;
+}
+
SystemMetrics
GetSystemMetricsForReporting()
{
@@ -318,12 +375,281 @@ GetSystemMetricsForReporting()
return Sm;
}
+///////////////////////////////////////////////////////////////////////////
+// SystemMetricsTracker
+///////////////////////////////////////////////////////////////////////////
+
+// Per-platform CPU sampling helper. Called with m_Mutex held.
+
+#if ZEN_PLATFORM_WINDOWS || ZEN_PLATFORM_LINUX
+
+// Samples CPU usage by reading /proc/stat. Used natively on Linux and as a
+// Wine fallback on Windows (where /proc/stat is accessible via the Z: drive).
+struct ProcStatCpuSampler
+{
+ const char* Path = "/proc/stat";
+ unsigned long PrevUser = 0;
+ unsigned long PrevNice = 0;
+ unsigned long PrevSystem = 0;
+ unsigned long PrevIdle = 0;
+ unsigned long PrevIoWait = 0;
+ unsigned long PrevIrq = 0;
+ unsigned long PrevSoftIrq = 0;
+
+ explicit ProcStatCpuSampler(const char* InPath = "/proc/stat") : Path(InPath) {}
+
+ float Sample()
+ {
+ float CpuUsage = 0.0f;
+
+ if (FILE* Stat = fopen(Path, "r"))
+ {
+ char Line[256];
+ unsigned long User, Nice, System, Idle, IoWait, Irq, SoftIrq;
+
+ if (fgets(Line, sizeof(Line), Stat))
+ {
+ if (sscanf(Line, "cpu %lu %lu %lu %lu %lu %lu %lu", &User, &Nice, &System, &Idle, &IoWait, &Irq, &SoftIrq) == 7)
+ {
+ unsigned long TotalDelta = (User + Nice + System + Idle + IoWait + Irq + SoftIrq) -
+ (PrevUser + PrevNice + PrevSystem + PrevIdle + PrevIoWait + PrevIrq + PrevSoftIrq);
+ unsigned long IdleDelta = Idle - PrevIdle;
+
+ if (TotalDelta > 0)
+ {
+ CpuUsage = 100.0f * (TotalDelta - IdleDelta) / TotalDelta;
+ }
+
+ PrevUser = User;
+ PrevNice = Nice;
+ PrevSystem = System;
+ PrevIdle = Idle;
+ PrevIoWait = IoWait;
+ PrevIrq = Irq;
+ PrevSoftIrq = SoftIrq;
+ }
+ }
+ fclose(Stat);
+ }
+
+ return CpuUsage;
+ }
+};
+
+#endif
+
+#if ZEN_PLATFORM_WINDOWS
+
+struct CpuSampler
+{
+ PDH_HQUERY QueryHandle = nullptr;
+ PDH_HCOUNTER CounterHandle = nullptr;
+ bool HasPreviousSample = false;
+ bool IsWine = false;
+ ProcStatCpuSampler ProcStat{"Z:\\proc\\stat"};
+
+ CpuSampler()
+ {
+ IsWine = zen::windows::IsRunningOnWine();
+
+ if (!IsWine)
+ {
+ if (PdhOpenQueryW(nullptr, 0, &QueryHandle) == ERROR_SUCCESS)
+ {
+ if (PdhAddEnglishCounterW(QueryHandle, L"\\Processor(_Total)\\% Processor Time", 0, &CounterHandle) != ERROR_SUCCESS)
+ {
+ CounterHandle = nullptr;
+ }
+ }
+ }
+ }
+
+ ~CpuSampler()
+ {
+ if (QueryHandle)
+ {
+ PdhCloseQuery(QueryHandle);
+ }
+ }
+
+ float Sample()
+ {
+ if (IsWine)
+ {
+ return ProcStat.Sample();
+ }
+
+ if (!QueryHandle || !CounterHandle)
+ {
+ return 0.0f;
+ }
+
+ PdhCollectQueryData(QueryHandle);
+
+ if (!HasPreviousSample)
+ {
+ HasPreviousSample = true;
+ return 0.0f;
+ }
+
+ PDH_FMT_COUNTERVALUE CounterValue;
+ if (PdhGetFormattedCounterValue(CounterHandle, PDH_FMT_DOUBLE, nullptr, &CounterValue) == ERROR_SUCCESS)
+ {
+ return static_cast<float>(CounterValue.doubleValue);
+ }
+
+ return 0.0f;
+ }
+};
+
+#elif ZEN_PLATFORM_LINUX
+
+struct CpuSampler
+{
+ ProcStatCpuSampler ProcStat;
+
+ float Sample() { return ProcStat.Sample(); }
+};
+
+#elif ZEN_PLATFORM_MAC
+
+struct CpuSampler
+{
+ unsigned long PrevTotalTicks = 0;
+ unsigned long PrevIdleTicks = 0;
+
+ float Sample()
+ {
+ float CpuUsage = 0.0f;
+
+ host_cpu_load_info_data_t CpuLoad;
+ mach_msg_type_number_t Count = sizeof(CpuLoad) / sizeof(natural_t);
+ if (host_statistics(mach_host_self(), HOST_CPU_LOAD_INFO, (host_info_t)&CpuLoad, &Count) == KERN_SUCCESS)
+ {
+ unsigned long TotalTicks = 0;
+ for (int i = 0; i < CPU_STATE_MAX; ++i)
+ {
+ TotalTicks += CpuLoad.cpu_ticks[i];
+ }
+ unsigned long IdleTicks = CpuLoad.cpu_ticks[CPU_STATE_IDLE];
+
+ unsigned long TotalDelta = TotalTicks - PrevTotalTicks;
+ unsigned long IdleDelta = IdleTicks - PrevIdleTicks;
+
+ if (TotalDelta > 0 && PrevTotalTicks > 0)
+ {
+ CpuUsage = 100.0f * (TotalDelta - IdleDelta) / TotalDelta;
+ }
+
+ PrevTotalTicks = TotalTicks;
+ PrevIdleTicks = IdleTicks;
+ }
+
+ return CpuUsage;
+ }
+};
+
+#endif
+
+struct SystemMetricsTracker::Impl
+{
+ using Clock = std::chrono::steady_clock;
+
+ std::mutex Mutex;
+ CpuSampler Sampler;
+ float CachedCpuPercent = 0.0f;
+ Clock::time_point NextSampleTime = Clock::now();
+ std::chrono::milliseconds MinInterval;
+
+ explicit Impl(std::chrono::milliseconds InMinInterval) : MinInterval(InMinInterval) {}
+
+ float SampleCpu()
+ {
+ const auto Now = Clock::now();
+ if (Now >= NextSampleTime)
+ {
+ CachedCpuPercent = Sampler.Sample();
+ NextSampleTime = Now + MinInterval;
+ }
+ return CachedCpuPercent;
+ }
+};
+
+SystemMetricsTracker::SystemMetricsTracker(std::chrono::milliseconds MinInterval) : m_Impl(std::make_unique<Impl>(MinInterval))
+{
+}
+
+SystemMetricsTracker::~SystemMetricsTracker() = default;
+
+ExtendedSystemMetrics
+SystemMetricsTracker::Query()
+{
+ ExtendedSystemMetrics Metrics;
+ static_cast<SystemMetrics&>(Metrics) = GetSystemMetrics();
+
+ std::lock_guard Lock(m_Impl->Mutex);
+ Metrics.CpuUsagePercent = m_Impl->SampleCpu();
+ return Metrics;
+}
+
+///////////////////////////////////////////////////////////////////////////
+
std::string_view
GetOperatingSystemName()
{
return ZEN_PLATFORM_NAME;
}
+std::string
+GetOperatingSystemVersion()
+{
+#if ZEN_PLATFORM_WINDOWS
+ // Use RtlGetVersion to avoid the compatibility shim that GetVersionEx applies
+ using RtlGetVersionFn = LONG(WINAPI*)(PRTL_OSVERSIONINFOW);
+ RTL_OSVERSIONINFOW OsVer{.dwOSVersionInfoSize = sizeof(OsVer)};
+ if (auto Fn = (RtlGetVersionFn)GetProcAddress(GetModuleHandleW(L"ntdll.dll"), "RtlGetVersion"))
+ {
+ Fn(&OsVer);
+ }
+ return fmt::format("Windows {}.{} Build {}", OsVer.dwMajorVersion, OsVer.dwMinorVersion, OsVer.dwBuildNumber);
+#elif ZEN_PLATFORM_LINUX
+ struct utsname Info
+ {
+ };
+ if (uname(&Info) == 0)
+ {
+ return fmt::format("{} {}", Info.sysname, Info.release);
+ }
+ return "Linux";
+#elif ZEN_PLATFORM_MAC
+ char OsVersion[64] = "";
+ size_t Size = sizeof(OsVersion);
+ if (sysctlbyname("kern.osproductversion", OsVersion, &Size, nullptr, 0) == 0)
+ {
+ return fmt::format("macOS {}", OsVersion);
+ }
+ return "macOS";
+#endif
+}
+
+std::string_view
+GetRuntimePlatformName()
+{
+#if ZEN_PLATFORM_WINDOWS
+ if (zen::windows::IsRunningOnWine())
+ {
+ return "wine"sv;
+ }
+ return "windows"sv;
+#elif ZEN_PLATFORM_LINUX
+ return "linux"sv;
+#elif ZEN_PLATFORM_MAC
+ return "macos"sv;
+#else
+ return "unknown"sv;
+#endif
+}
+
std::string_view
GetCpuName()
{
@@ -340,7 +666,14 @@ Describe(const SystemMetrics& Metrics, CbWriter& Writer)
Writer << "cpu_count" << Metrics.CpuCount << "core_count" << Metrics.CoreCount << "lp_count" << Metrics.LogicalProcessorCount
<< "total_memory_mb" << Metrics.SystemMemoryMiB << "avail_memory_mb" << Metrics.AvailSystemMemoryMiB << "total_virtual_mb"
<< Metrics.VirtualMemoryMiB << "avail_virtual_mb" << Metrics.AvailVirtualMemoryMiB << "total_pagefile_mb" << Metrics.PageFileMiB
- << "avail_pagefile_mb" << Metrics.AvailPageFileMiB;
+ << "avail_pagefile_mb" << Metrics.AvailPageFileMiB << "uptime_seconds" << Metrics.UptimeSeconds;
+}
+
+void
+Describe(const ExtendedSystemMetrics& Metrics, CbWriter& Writer)
+{
+ Describe(static_cast<const SystemMetrics&>(Metrics), Writer);
+ Writer << "cpu_usage_percent" << Metrics.CpuUsagePercent;
}
} // namespace zen
diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp
index 936424e0f..089e376bb 100644
--- a/src/zencore/testing.cpp
+++ b/src/zencore/testing.cpp
@@ -1,11 +1,22 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#define ZEN_TEST_WITH_RUNNER 1
+
#include "zencore/testing.h"
+
+#include "zencore/filesystem.h"
#include "zencore/logging.h"
+#include "zencore/process.h"
+#include "zencore/trace.h"
#if ZEN_WITH_TESTS
-# include <doctest/doctest.h>
+# include <chrono>
+# include <clocale>
+# include <cstdlib>
+# include <cstdio>
+# include <string>
+# include <vector>
namespace zen::testing {
@@ -21,9 +32,35 @@ struct TestListener : public doctest::IReporter
void report_query(const doctest::QueryData& /*in*/) override {}
- void test_run_start() override {}
+ void test_run_start() override { RunStart = std::chrono::steady_clock::now(); }
- void test_run_end(const doctest::TestRunStats& /*in*/) override {}
+ void test_run_end(const doctest::TestRunStats& in) override
+ {
+ auto elapsed = std::chrono::steady_clock::now() - RunStart;
+ double elapsedSeconds = std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count() / 1000.0;
+
+ // Write machine-readable summary to file if requested (used by xmake test summary table)
+ const char* summaryFile = std::getenv("ZEN_TEST_SUMMARY_FILE");
+ if (summaryFile && summaryFile[0] != '\0')
+ {
+ if (FILE* f = std::fopen(summaryFile, "w"))
+ {
+ std::fprintf(f,
+ "cases_total=%u\ncases_passed=%u\nassertions_total=%d\nassertions_passed=%d\n"
+ "elapsed_seconds=%.3f\n",
+ in.numTestCasesPassingFilters,
+ in.numTestCasesPassingFilters - in.numTestCasesFailed,
+ in.numAsserts,
+ in.numAsserts - in.numAssertsFailed,
+ elapsedSeconds);
+ for (const auto& failure : FailedTests)
+ {
+ std::fprintf(f, "failed=%s|%s|%u\n", failure.Name.c_str(), failure.File.c_str(), failure.Line);
+ }
+ std::fclose(f);
+ }
+ }
+ }
void test_case_start(const doctest::TestCaseData& in) override
{
@@ -37,7 +74,14 @@ struct TestListener : public doctest::IReporter
ZEN_CONSOLE("{}-------------------------------------------------------------------------------{}", ColorYellow, ColorNone);
}
- void test_case_end(const doctest::CurrentTestCaseStats& /*in*/) override { Current = nullptr; }
+ void test_case_end(const doctest::CurrentTestCaseStats& in) override
+ {
+ if (!in.testCaseSuccess && Current)
+ {
+ FailedTests.push_back({Current->m_name, Current->m_file.c_str(), Current->m_line});
+ }
+ Current = nullptr;
+ }
void test_case_exception(const doctest::TestCaseException& /*in*/) override {}
@@ -57,7 +101,16 @@ struct TestListener : public doctest::IReporter
void test_case_skipped(const doctest::TestCaseData& /*in*/) override {}
- const doctest::TestCaseData* Current = nullptr;
+ const doctest::TestCaseData* Current = nullptr;
+ std::chrono::steady_clock::time_point RunStart = {};
+
+ struct FailedTestInfo
+ {
+ std::string Name;
+ std::string File;
+ unsigned Line;
+ };
+ std::vector<FailedTestInfo> FailedTests;
};
struct TestRunner::Impl
@@ -75,20 +128,26 @@ TestRunner::~TestRunner()
{
}
+void
+TestRunner::SetDefaultSuiteFilter(const char* Pattern)
+{
+ m_Impl->Session.setOption("test-suite", Pattern);
+}
+
int
-TestRunner::ApplyCommandLine(int argc, char const* const* argv)
+TestRunner::ApplyCommandLine(int Argc, char const* const* Argv)
{
- m_Impl->Session.applyCommandLine(argc, argv);
+ m_Impl->Session.applyCommandLine(Argc, Argv);
- for (int i = 1; i < argc; ++i)
+ for (int i = 1; i < Argc; ++i)
{
- if (argv[i] == "--debug"sv)
+ if (Argv[i] == "--debug"sv)
{
- zen::logging::SetLogLevel(zen::logging::level::Debug);
+ zen::logging::SetLogLevel(zen::logging::Debug);
}
- else if (argv[i] == "--verbose"sv)
+ else if (Argv[i] == "--verbose"sv)
{
- zen::logging::SetLogLevel(zen::logging::level::Trace);
+ zen::logging::SetLogLevel(zen::logging::Trace);
}
}
@@ -101,6 +160,57 @@ TestRunner::Run()
return m_Impl->Session.run();
}
+int
+RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink)())
+{
+# if ZEN_PLATFORM_WINDOWS
+ setlocale(LC_ALL, "en_us.UTF8");
+# endif
+
+ ForceLink();
+
+# if ZEN_PLATFORM_LINUX
+ zen::IgnoreChildSignals();
+# endif
+
+# if ZEN_WITH_TRACE
+ zen::TraceInit(ExecutableName);
+ zen::TraceOptions TraceCommandlineOptions;
+ if (GetTraceOptionsFromCommandline(TraceCommandlineOptions))
+ {
+ TraceConfigure(TraceCommandlineOptions);
+ }
+# endif
+
+ zen::logging::InitializeLogging();
+ zen::MaximizeOpenFileCount();
+
+ TestRunner Runner;
+
+ // Derive default suite filter from ExecutableName: "zencore-test" -> "core.*"
+ if (ExecutableName)
+ {
+ std::string_view Name = ExecutableName;
+ if (Name.starts_with("zen"))
+ {
+ Name.remove_prefix(3);
+ }
+ if (Name.ends_with("-test"))
+ {
+ Name.remove_suffix(5);
+ }
+ if (!Name.empty())
+ {
+ std::string Filter(Name);
+ Filter += ".*";
+ Runner.SetDefaultSuiteFilter(Filter.c_str());
+ }
+ }
+
+ Runner.ApplyCommandLine(Argc, Argv);
+ return Runner.Run();
+}
+
} // namespace zen::testing
#endif // ZEN_WITH_TESTS
diff --git a/src/zencore/testutils.cpp b/src/zencore/testutils.cpp
index 5bc2841ae..0cd3f8121 100644
--- a/src/zencore/testutils.cpp
+++ b/src/zencore/testutils.cpp
@@ -46,7 +46,7 @@ ScopedTemporaryDirectory::~ScopedTemporaryDirectory()
IoBuffer
CreateRandomBlob(uint64_t Size)
{
- static FastRandom Rand{.Seed = 0x7CEBF54E45B9F5D1};
+ thread_local FastRandom Rand{.Seed = 0x7CEBF54E45B9F5D1};
return CreateRandomBlob(Rand, Size);
};
diff --git a/src/zencore/thread.cpp b/src/zencore/thread.cpp
index 9e3486e49..54459cbaa 100644
--- a/src/zencore/thread.cpp
+++ b/src/zencore/thread.cpp
@@ -133,7 +133,10 @@ SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName)
#elif ZEN_PLATFORM_MAC
pthread_setname_np(ThreadNameZ.c_str());
#else
- pthread_setname_np(pthread_self(), ThreadNameZ.c_str());
+ // Linux pthread_setname_np has a 16-byte limit (15 chars + NUL)
+ StringBuilder<16> LinuxThreadName;
+ LinuxThreadName << LimitedThreadName.substr(0, 15);
+ pthread_setname_np(pthread_self(), LinuxThreadName.c_str());
#endif
} // namespace zen
@@ -233,12 +236,15 @@ Event::Close()
#else
std::atomic_thread_fence(std::memory_order_acquire);
auto* Inner = (EventInner*)m_EventHandle.load();
+ if (Inner)
{
- std::unique_lock Lock(Inner->Mutex);
- Inner->bSet.store(true);
- m_EventHandle = nullptr;
+ {
+ std::unique_lock Lock(Inner->Mutex);
+ Inner->bSet.store(true);
+ m_EventHandle = nullptr;
+ }
+ delete Inner;
}
- delete Inner;
#endif
}
@@ -351,7 +357,7 @@ NamedEvent::NamedEvent(std::string_view EventName)
intptr_t Packed;
Packed = intptr_t(Sem) << 32;
Packed |= intptr_t(Fd) & 0xffff'ffff;
- m_EventHandle = (void*)Packed;
+ m_EventHandle = (void*)Packed;
#endif
ZEN_ASSERT(m_EventHandle != nullptr);
}
@@ -372,7 +378,9 @@ NamedEvent::Close()
#if ZEN_PLATFORM_WINDOWS
CloseHandle(m_EventHandle);
#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
- int Fd = int(intptr_t(m_EventHandle.load()) & 0xffff'ffff);
+ const intptr_t Handle = intptr_t(m_EventHandle.load());
+ const int Fd = int(Handle & 0xffff'ffff);
+ const int Sem = int(Handle >> 32);
if (flock(Fd, LOCK_EX | LOCK_NB) == 0)
{
@@ -388,11 +396,10 @@ NamedEvent::Close()
}
flock(Fd, LOCK_UN | LOCK_NB);
- close(Fd);
-
- int Sem = int(intptr_t(m_EventHandle.load()) >> 32);
semctl(Sem, 0, IPC_RMID);
}
+
+ close(Fd);
#endif
m_EventHandle = nullptr;
@@ -481,9 +488,12 @@ NamedMutex::~NamedMutex()
CloseHandle(m_MutexHandle);
}
#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
- int Inner = int(intptr_t(m_MutexHandle));
- flock(Inner, LOCK_UN);
- close(Inner);
+ if (m_MutexHandle)
+ {
+ int Inner = int(intptr_t(m_MutexHandle));
+ flock(Inner, LOCK_UN);
+ close(Inner);
+ }
#endif
}
@@ -516,7 +526,6 @@ NamedMutex::Create(std::string_view MutexName)
if (flock(Inner, LOCK_EX) != 0)
{
close(Inner);
- Inner = 0;
return false;
}
@@ -583,6 +592,11 @@ GetCurrentThreadId()
void
Sleep(int ms)
{
+ if (ms <= 0)
+ {
+ return;
+ }
+
#if ZEN_PLATFORM_WINDOWS
::Sleep(ms);
#else
diff --git a/src/zencore/trace.cpp b/src/zencore/trace.cpp
index 87035554f..7c195e69f 100644
--- a/src/zencore/trace.cpp
+++ b/src/zencore/trace.cpp
@@ -10,7 +10,16 @@
# define TRACE_IMPLEMENT 1
# undef _WINSOCK_DEPRECATED_NO_WARNINGS
+// GCC false positives in thirdparty trace.h (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100137)
+# if ZEN_COMPILER_GCC
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wstringop-overread"
+# pragma GCC diagnostic ignored "-Wdangling-pointer"
+# endif
# include <zencore/trace.h>
+# if ZEN_COMPILER_GCC
+# pragma GCC diagnostic pop
+# endif
# include <zencore/memory/fmalloc.h>
# include <zencore/memory/memorytrace.h>
@@ -165,10 +174,17 @@ GetTraceOptionsFromCommandline(TraceOptions& OutOptions)
auto MatchesArg = [](std::string_view Option, std::string_view Arg) -> std::optional<std::string_view> {
if (Arg.starts_with(Option))
{
- std::string_view::value_type DelimChar = Arg[Option.length()];
- if (DelimChar == ' ' || DelimChar == '=')
+ if (Arg.length() > Option.length())
+ {
+ std::string_view::value_type DelimChar = Arg[Option.length()];
+ if (DelimChar == ' ' || DelimChar == '=')
+ {
+ return Arg.substr(Option.size() + 1);
+ }
+ }
+ else
{
- return Arg.substr(Option.size() + 1);
+ return ""sv;
}
}
return {};
diff --git a/src/zencore/uid.cpp b/src/zencore/uid.cpp
index d7636f2ad..971683721 100644
--- a/src/zencore/uid.cpp
+++ b/src/zencore/uid.cpp
@@ -156,6 +156,8 @@ Oid::FromMemory(const void* Ptr)
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("core.uid");
+
TEST_CASE("Oid")
{
SUBCASE("Basic")
@@ -185,6 +187,8 @@ TEST_CASE("Oid")
}
}
+TEST_SUITE_END();
+
void
uid_forcelink()
{
diff --git a/src/zencore/windows.cpp b/src/zencore/windows.cpp
index d02fcd35e..87f854b90 100644
--- a/src/zencore/windows.cpp
+++ b/src/zencore/windows.cpp
@@ -12,14 +12,12 @@ namespace zen::windows {
bool
IsRunningOnWine()
{
- HMODULE NtDll = GetModuleHandleA("ntdll.dll");
+ static bool s_Result = [] {
+ HMODULE NtDll = GetModuleHandleA("ntdll.dll");
+ return NtDll && !!GetProcAddress(NtDll, "wine_get_version");
+ }();
- if (NtDll)
- {
- return !!GetProcAddress(NtDll, "wine_get_version");
- }
-
- return false;
+ return s_Result;
}
FileMapping::FileMapping(_In_ FileMapping& orig)
diff --git a/src/zencore/workthreadpool.cpp b/src/zencore/workthreadpool.cpp
index cb84bbe06..1cb338c66 100644
--- a/src/zencore/workthreadpool.cpp
+++ b/src/zencore/workthreadpool.cpp
@@ -354,6 +354,8 @@ workthreadpool_forcelink()
using namespace std::literals;
+TEST_SUITE_BEGIN("core.workthreadpool");
+
TEST_CASE("threadpool.basic")
{
WorkerThreadPool Threadpool{1};
@@ -368,6 +370,8 @@ TEST_CASE("threadpool.basic")
CHECK_THROWS(FutureThrow.get());
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/xmake.lua b/src/zencore/xmake.lua
index a3fd4dacb..171f4c533 100644
--- a/src/zencore/xmake.lua
+++ b/src/zencore/xmake.lua
@@ -15,6 +15,7 @@ target('zencore')
set_configdir("include/zencore")
add_files("**.cpp")
add_files("trace.cpp", {unity_ignored = true })
+ add_files("testing.cpp", {unity_ignored = true })
if has_config("zenrpmalloc") then
add_deps("rpmalloc")
@@ -25,7 +26,6 @@ target('zencore')
end
add_deps("zenbase")
- add_deps("spdlog")
add_deps("utfcpp")
add_deps("oodle")
add_deps("blake3")
@@ -33,8 +33,6 @@ target('zencore')
add_deps("timesinceprocessstart")
add_deps("doctest")
add_deps("fmt")
- add_deps("ryml")
-
add_packages("json11")
if is_plat("linux", "macosx") then
diff --git a/src/zencore/xxhash.cpp b/src/zencore/xxhash.cpp
index 6d1050531..88a48dd68 100644
--- a/src/zencore/xxhash.cpp
+++ b/src/zencore/xxhash.cpp
@@ -59,6 +59,8 @@ xxhash_forcelink()
{
}
+TEST_SUITE_BEGIN("core.xxhash");
+
TEST_CASE("XXH3_128")
{
using namespace std::literals;
@@ -96,6 +98,8 @@ TEST_CASE("XXH3_128")
}
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zencore/zencore.cpp b/src/zencore/zencore.cpp
index 4ff79edc7..8c29a8962 100644
--- a/src/zencore/zencore.cpp
+++ b/src/zencore/zencore.cpp
@@ -147,7 +147,7 @@ AssertImpl::OnAssert(const char* Filename, int LineNumber, const char* FunctionN
Message.push_back('\0');
// We use direct ZEN_LOG here instead of ZEN_ERROR as we don't care about *this* code location in the log
- ZEN_LOG(Log(), zen::logging::level::Err, "{}", Message.data());
+ ZEN_LOG(Log(), zen::logging::Err, "{}", Message.data());
zen::logging::FlushLogging();
}
@@ -285,7 +285,7 @@ zencore_forcelinktests()
namespace zen {
-TEST_SUITE_BEGIN("core.assert");
+TEST_SUITE_BEGIN("core.zencore");
TEST_CASE("Assert.Default")
{
diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp
new file mode 100644
index 000000000..819b2d0cb
--- /dev/null
+++ b/src/zenhorde/hordeagent.cpp
@@ -0,0 +1,297 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordeagent.h"
+#include "hordetransportaes.h"
+
+#include <zencore/basicfile.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/trace.h>
+
+#include <cstring>
+#include <unordered_map>
+
+namespace zen::horde {
+
+HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info)
+{
+ ZEN_TRACE_CPU("HordeAgent::Connect");
+
+ auto Transport = std::make_unique<TcpComputeTransport>(Info);
+ if (!Transport->IsValid())
+ {
+ ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort());
+ return;
+ }
+
+ // The 64-byte nonce is always sent unencrypted as the first thing on the wire.
+ // The Horde agent uses this to identify which lease this connection belongs to.
+ Transport->Send(Info.Nonce, sizeof(Info.Nonce));
+
+ std::unique_ptr<ComputeTransport> FinalTransport = std::move(Transport);
+ if (Info.EncryptionMode == Encryption::AES)
+ {
+ FinalTransport = std::make_unique<AesComputeTransport>(Info.Key, std::move(FinalTransport));
+ if (!FinalTransport->IsValid())
+ {
+ ZEN_WARN("failed to create AES transport");
+ return;
+ }
+ }
+
+ // Create multiplexed socket and channels
+ m_Socket = std::make_unique<ComputeSocket>(std::move(FinalTransport));
+
+ // Channel 0 is the agent control channel (handles Attach/Fork handshake).
+ // Channel 100 is the child I/O channel (handles file upload and remote execution).
+ Ref<ComputeChannel> AgentComputeChannel = m_Socket->CreateChannel(0);
+ Ref<ComputeChannel> ChildComputeChannel = m_Socket->CreateChannel(100);
+
+ if (!AgentComputeChannel || !ChildComputeChannel)
+ {
+ ZEN_WARN("failed to create compute channels");
+ return;
+ }
+
+ m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(AgentComputeChannel));
+ m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildComputeChannel));
+
+ m_IsValid = true;
+}
+
+HordeAgent::~HordeAgent()
+{
+ CloseConnection();
+}
+
+bool
+HordeAgent::BeginCommunication()
+{
+ ZEN_TRACE_CPU("HordeAgent::BeginCommunication");
+
+ if (!m_IsValid)
+ {
+ return false;
+ }
+
+ // Start the send/recv pump threads
+ m_Socket->StartCommunication();
+
+ // Wait for Attach on agent channel
+ AgentMessageType Type = m_AgentChannel->ReadResponse(5000);
+ if (Type == AgentMessageType::None)
+ {
+ ZEN_WARN("timed out waiting for Attach on agent channel");
+ return false;
+ }
+ if (Type != AgentMessageType::Attach)
+ {
+ ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type));
+ return false;
+ }
+
+ // Fork tells the remote agent to create child channel 100 with a 4MB buffer.
+ // After this, the agent will send an Attach on the child channel.
+ m_AgentChannel->Fork(100, 4 * 1024 * 1024);
+
+ // Wait for Attach on child channel
+ Type = m_ChildChannel->ReadResponse(5000);
+ if (Type == AgentMessageType::None)
+ {
+ ZEN_WARN("timed out waiting for Attach on child channel");
+ return false;
+ }
+ if (Type != AgentMessageType::Attach)
+ {
+ ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type));
+ return false;
+ }
+
+ return true;
+}
+
+bool
+HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator)
+{
+ ZEN_TRACE_CPU("HordeAgent::UploadBinaries");
+
+ m_ChildChannel->UploadFiles("", BundleLocator.c_str());
+
+ std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles;
+
+ auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* {
+ std::string Key(Locator);
+
+ if (auto It = BlobFiles.find(Key); It != BlobFiles.end())
+ {
+ return It->second.get();
+ }
+
+ const std::filesystem::path Path = BundleDir / (Key + ".blob");
+ std::error_code Ec;
+ auto File = std::make_unique<BasicFile>();
+ File->Open(Path, BasicFile::Mode::kRead, Ec);
+
+ if (Ec)
+ {
+ ZEN_ERROR("cannot read blob file: '{}'", Path);
+ return nullptr;
+ }
+
+ BasicFile* Ptr = File.get();
+ BlobFiles.emplace(std::move(Key), std::move(File));
+ return Ptr;
+ };
+
+ // The upload protocol is request-driven: we send WriteFiles, then the remote agent
+ // sends ReadBlob requests for each blob it needs. We respond with Blob data until
+ // the agent sends WriteFilesResponse indicating the upload is complete.
+ constexpr int32_t ReadResponseTimeoutMs = 1000;
+
+ for (;;)
+ {
+ bool TimedOut = false;
+
+ if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob)
+ {
+ if (TimedOut)
+ {
+ continue;
+ }
+ // End of stream - check if it was a successful upload
+ if (Type == AgentMessageType::WriteFilesResponse)
+ {
+ return true;
+ }
+ else if (Type == AgentMessageType::Exception)
+ {
+ ExceptionInfo Ex;
+ m_ChildChannel->ReadException(Ex);
+ ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description);
+ }
+ else
+ {
+ ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type));
+ }
+ return false;
+ }
+
+ BlobRequest Req;
+ m_ChildChannel->ReadBlobRequest(Req);
+
+ BasicFile* File = FindOrOpenBlob(Req.Locator);
+ if (!File)
+ {
+ return false;
+ }
+
+ // Read from offset to end of file
+ const uint64_t TotalSize = File->FileSize();
+ const uint64_t Offset = static_cast<uint64_t>(Req.Offset);
+ if (Offset >= TotalSize)
+ {
+ ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize);
+ m_ChildChannel->Blob(nullptr, 0);
+ continue;
+ }
+
+ const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset));
+ m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize());
+ }
+}
+
+void
+HordeAgent::Execute(const char* Exe,
+ const char* const* Args,
+ size_t NumArgs,
+ const char* WorkingDir,
+ const char* const* EnvVars,
+ size_t NumEnvVars,
+ bool UseWine)
+{
+ ZEN_TRACE_CPU("HordeAgent::Execute");
+ m_ChildChannel
+ ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None);
+}
+
+bool
+HordeAgent::Poll(bool LogOutput)
+{
+ constexpr int32_t ReadResponseTimeoutMs = 100;
+ AgentMessageType Type;
+
+ while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None)
+ {
+ switch (Type)
+ {
+ case AgentMessageType::ExecuteOutput:
+ {
+ if (LogOutput && m_ChildChannel->GetResponseSize() > 0)
+ {
+ const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData());
+ size_t ResponseSize = m_ChildChannel->GetResponseSize();
+
+ // Trim trailing newlines
+ while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r'))
+ {
+ --ResponseSize;
+ }
+
+ if (ResponseSize > 0)
+ {
+ const std::string_view Output(ResponseData, ResponseSize);
+ ZEN_INFO("[remote] {}", Output);
+ }
+ }
+ break;
+ }
+
+ case AgentMessageType::ExecuteResult:
+ {
+ if (m_ChildChannel->GetResponseSize() == sizeof(int32_t))
+ {
+ int32_t ExitCode;
+ memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t));
+ ZEN_INFO("remote process exited with code {}", ExitCode);
+ }
+ m_IsValid = false;
+ return false;
+ }
+
+ case AgentMessageType::Exception:
+ {
+ ExceptionInfo Ex;
+ m_ChildChannel->ReadException(Ex);
+ ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description);
+ m_HasErrors = true;
+ break;
+ }
+
+ default:
+ break;
+ }
+ }
+
+ return m_IsValid && !m_HasErrors;
+}
+
+void
+HordeAgent::CloseConnection()
+{
+ if (m_ChildChannel)
+ {
+ m_ChildChannel->Close();
+ }
+ if (m_AgentChannel)
+ {
+ m_AgentChannel->Close();
+ }
+}
+
+bool
+HordeAgent::IsValid() const
+{
+ return m_IsValid && !m_HasErrors;
+}
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordeagent.h b/src/zenhorde/hordeagent.h
new file mode 100644
index 000000000..e0ae89ead
--- /dev/null
+++ b/src/zenhorde/hordeagent.h
@@ -0,0 +1,77 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "hordeagentmessage.h"
+#include "hordecomputesocket.h"
+
+#include <zenhorde/hordeclient.h>
+
+#include <zencore/logbase.h>
+
+#include <filesystem>
+#include <memory>
+#include <string>
+
+namespace zen::horde {
+
+/** Manages the lifecycle of a single Horde compute agent.
+ *
+ * Handles the full connection sequence for one provisioned machine:
+ * 1. Connect via TCP transport (with optional AES encryption wrapping)
+ * 2. Create a multiplexed ComputeSocket with agent (channel 0) and child (channel 100)
+ * 3. Perform the Attach/Fork handshake to establish the child channel
+ * 4. Upload zenserver binary via the WriteFiles/ReadBlob protocol
+ * 5. Execute zenserver remotely via ExecuteV2
+ * 6. Poll for ExecuteOutput (stdout) and ExecuteResult (exit code)
+ */
+class HordeAgent
+{
+public:
+ explicit HordeAgent(const MachineInfo& Info);
+ ~HordeAgent();
+
+ HordeAgent(const HordeAgent&) = delete;
+ HordeAgent& operator=(const HordeAgent&) = delete;
+
+ /** Perform the channel setup handshake (Attach on agent channel, Fork, Attach on child channel).
+ * Returns false if the handshake times out or receives an unexpected message. */
+ bool BeginCommunication();
+
+ /** Upload binary files to the remote agent.
+ * @param BundleDir Directory containing .blob files.
+ * @param BundleLocator Locator string identifying the bundle (from CreateBundle). */
+ bool UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator);
+
+ /** Execute a command on the remote machine. */
+ void Execute(const char* Exe,
+ const char* const* Args,
+ size_t NumArgs,
+ const char* WorkingDir = nullptr,
+ const char* const* EnvVars = nullptr,
+ size_t NumEnvVars = 0,
+ bool UseWine = false);
+
+ /** Poll for output and results. Returns true if the agent is still running.
+ * When LogOutput is true, remote stdout is logged via ZEN_INFO. */
+ bool Poll(bool LogOutput = true);
+
+ void CloseConnection();
+ bool IsValid() const;
+
+ const MachineInfo& GetMachineInfo() const { return m_MachineInfo; }
+
+private:
+ LoggerRef Log() { return m_Log; }
+
+ std::unique_ptr<ComputeSocket> m_Socket;
+ std::unique_ptr<AgentMessageChannel> m_AgentChannel; ///< Channel 0: agent control
+ std::unique_ptr<AgentMessageChannel> m_ChildChannel; ///< Channel 100: child I/O
+
+ LoggerRef m_Log;
+ bool m_IsValid = false;
+ bool m_HasErrors = false;
+ MachineInfo m_MachineInfo;
+};
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp
new file mode 100644
index 000000000..998134a96
--- /dev/null
+++ b/src/zenhorde/hordeagentmessage.cpp
@@ -0,0 +1,340 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordeagentmessage.h"
+
+#include <zencore/intmath.h>
+
+#include <cassert>
+#include <cstring>
+
+namespace zen::horde {
+
+AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel))
+{
+}
+
+AgentMessageChannel::~AgentMessageChannel() = default;
+
+void
+AgentMessageChannel::Close()
+{
+ CreateMessage(AgentMessageType::None, 0);
+ FlushMessage();
+}
+
+void
+AgentMessageChannel::Ping()
+{
+ CreateMessage(AgentMessageType::Ping, 0);
+ FlushMessage();
+}
+
+void
+AgentMessageChannel::Fork(int ChannelId, int BufferSize)
+{
+ CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int));
+ WriteInt32(ChannelId);
+ WriteInt32(BufferSize);
+ FlushMessage();
+}
+
+void
+AgentMessageChannel::Attach()
+{
+ CreateMessage(AgentMessageType::Attach, 0);
+ FlushMessage();
+}
+
+void
+AgentMessageChannel::UploadFiles(const char* Path, const char* Locator)
+{
+ CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20);
+ WriteString(Path);
+ WriteString(Locator);
+ FlushMessage();
+}
+
+void
+AgentMessageChannel::Execute(const char* Exe,
+ const char* const* Args,
+ size_t NumArgs,
+ const char* WorkingDir,
+ const char* const* EnvVars,
+ size_t NumEnvVars,
+ ExecuteProcessFlags Flags)
+{
+ size_t RequiredSize = 50 + strlen(Exe);
+ for (size_t i = 0; i < NumArgs; ++i)
+ {
+ RequiredSize += strlen(Args[i]) + 10;
+ }
+ if (WorkingDir)
+ {
+ RequiredSize += strlen(WorkingDir) + 10;
+ }
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ RequiredSize += strlen(EnvVars[i]) + 20;
+ }
+
+ CreateMessage(AgentMessageType::ExecuteV2, RequiredSize);
+ WriteString(Exe);
+
+ WriteUnsignedVarInt(NumArgs);
+ for (size_t i = 0; i < NumArgs; ++i)
+ {
+ WriteString(Args[i]);
+ }
+
+ WriteOptionalString(WorkingDir);
+
+ // ExecuteV2 protocol requires env vars as separate key/value pairs.
+ // Callers pass "KEY=VALUE" strings; we split on the first '=' here.
+ WriteUnsignedVarInt(NumEnvVars);
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ const char* Eq = strchr(EnvVars[i], '=');
+ assert(Eq != nullptr);
+
+ WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i]));
+ if (*(Eq + 1) == '\0')
+ {
+ WriteOptionalString(nullptr);
+ }
+ else
+ {
+ WriteOptionalString(Eq + 1);
+ }
+ }
+
+ WriteInt32(static_cast<int>(Flags));
+ FlushMessage();
+}
+
+void
+AgentMessageChannel::Blob(const uint8_t* Data, size_t Length)
+{
+ // Blob responses are chunked to fit within the compute buffer's chunk size.
+ // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields).
+ const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength;
+ for (size_t ChunkOffset = 0; ChunkOffset < Length;)
+ {
+ const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize);
+
+ CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128);
+ WriteInt32(static_cast<int>(ChunkOffset));
+ WriteInt32(static_cast<int>(Length));
+ WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength);
+ FlushMessage();
+
+ ChunkOffset += ChunkLength;
+ }
+}
+
+AgentMessageType
+AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut)
+{
+ // Deferred advance: the previous response's buffer is only released when the next
+ // ReadResponse is called. This allows callers to read response data between calls
+ // without copying, since the pointer comes directly from the ring buffer.
+ if (m_ResponseData)
+ {
+ m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength);
+ m_ResponseData = nullptr;
+ m_ResponseLength = 0;
+ }
+
+ const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut);
+ if (!Header)
+ {
+ return AgentMessageType::None;
+ }
+
+ uint32_t Length;
+ memcpy(&Length, Header + 1, sizeof(uint32_t));
+
+ Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut);
+ if (!Header)
+ {
+ return AgentMessageType::None;
+ }
+
+ m_ResponseType = static_cast<AgentMessageType>(Header[0]);
+ m_ResponseData = Header + MessageHeaderLength;
+ m_ResponseLength = Length;
+
+ return m_ResponseType;
+}
+
+void
+AgentMessageChannel::ReadException(ExceptionInfo& Ex)
+{
+ assert(m_ResponseType == AgentMessageType::Exception);
+ const uint8_t* Pos = m_ResponseData;
+ Ex.Message = ReadString(&Pos);
+ Ex.Description = ReadString(&Pos);
+}
+
+int
+AgentMessageChannel::ReadExecuteResult()
+{
+ assert(m_ResponseType == AgentMessageType::ExecuteResult);
+ const uint8_t* Pos = m_ResponseData;
+ return ReadInt32(&Pos);
+}
+
+void
+AgentMessageChannel::ReadBlobRequest(BlobRequest& Req)
+{
+ assert(m_ResponseType == AgentMessageType::ReadBlob);
+ const uint8_t* Pos = m_ResponseData;
+ Req.Locator = ReadString(&Pos);
+ Req.Offset = ReadUnsignedVarInt(&Pos);
+ Req.Length = ReadUnsignedVarInt(&Pos);
+}
+
+void
+AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength)
+{
+ m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength);
+ m_RequestData[0] = static_cast<uint8_t>(Type);
+ m_MaxRequestSize = MaxLength;
+ m_RequestSize = 0;
+}
+
+void
+AgentMessageChannel::FlushMessage()
+{
+ const uint32_t Size = static_cast<uint32_t>(m_RequestSize);
+ memcpy(&m_RequestData[1], &Size, sizeof(uint32_t));
+ m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize);
+ m_RequestSize = 0;
+ m_MaxRequestSize = 0;
+ m_RequestData = nullptr;
+}
+
+void
+AgentMessageChannel::WriteInt32(int Value)
+{
+ WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int));
+}
+
+int
+AgentMessageChannel::ReadInt32(const uint8_t** Pos)
+{
+ int Value;
+ memcpy(&Value, *Pos, sizeof(int));
+ *Pos += sizeof(int);
+ return Value;
+}
+
+void
+AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length)
+{
+ assert(m_RequestSize + Length <= m_MaxRequestSize);
+ memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length);
+ m_RequestSize += Length;
+}
+
+const uint8_t*
+AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length)
+{
+ const uint8_t* Data = *Pos;
+ *Pos += Length;
+ return Data;
+}
+
+size_t
+AgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
+{
+ if (Value == 0)
+ {
+ return 1;
+ }
+ return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
+}
+
+void
+AgentMessageChannel::WriteUnsignedVarInt(size_t Value)
+{
+ const size_t ByteCount = MeasureUnsignedVarInt(Value);
+ assert(m_RequestSize + ByteCount <= m_MaxRequestSize);
+
+ uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize;
+ for (size_t i = 1; i < ByteCount; ++i)
+ {
+ Output[ByteCount - i] = static_cast<uint8_t>(Value);
+ Value >>= 8;
+ }
+ Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
+
+ m_RequestSize += ByteCount;
+}
+
+size_t
+AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos)
+{
+ const uint8_t* Data = *Pos;
+ const uint8_t FirstByte = Data[0];
+ const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24;
+
+ size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
+ for (size_t i = 1; i < NumBytes; ++i)
+ {
+ Value <<= 8;
+ Value |= Data[i];
+ }
+
+ *Pos += NumBytes;
+ return Value;
+}
+
+size_t
+AgentMessageChannel::MeasureString(const char* Text) const
+{
+ const size_t Length = strlen(Text);
+ return MeasureUnsignedVarInt(Length) + Length;
+}
+
+void
+AgentMessageChannel::WriteString(const char* Text)
+{
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Length);
+ WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+}
+
+void
+AgentMessageChannel::WriteString(std::string_view Text)
+{
+ WriteUnsignedVarInt(Text.size());
+ WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+}
+
+std::string_view
+AgentMessageChannel::ReadString(const uint8_t** Pos)
+{
+ const size_t Length = ReadUnsignedVarInt(Pos);
+ const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length));
+ return std::string_view(Start, Length);
+}
+
+void
+AgentMessageChannel::WriteOptionalString(const char* Text)
+{
+ // Optional strings use length+1 encoding: 0 means null/absent,
+ // N>0 means a string of length N-1 follows. This matches the UE
+ // FAgentMessageChannel serialization convention.
+ if (!Text)
+ {
+ WriteUnsignedVarInt(0);
+ }
+ else
+ {
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Length + 1);
+ WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+ }
+}
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h
new file mode 100644
index 000000000..38c4375fd
--- /dev/null
+++ b/src/zenhorde/hordeagentmessage.h
@@ -0,0 +1,161 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenbase/zenbase.h>
+
+#include "hordecomputechannel.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace zen::horde {
+
+/** Agent message types matching the UE EAgentMessageType byte values.
+ * These are the message opcodes exchanged over the agent/child channels. */
+enum class AgentMessageType : uint8_t
+{
+ None = 0x00,
+ Ping = 0x01,
+ Exception = 0x02,
+ Fork = 0x03,
+ Attach = 0x04,
+ WriteFiles = 0x10,
+ WriteFilesResponse = 0x11,
+ DeleteFiles = 0x12,
+ ExecuteV2 = 0x22,
+ ExecuteOutput = 0x17,
+ ExecuteResult = 0x18,
+ ReadBlob = 0x20,
+ ReadBlobResponse = 0x21,
+};
+
+/** Flags for the ExecuteV2 message. */
+enum class ExecuteProcessFlags : uint8_t
+{
+ None = 0,
+ UseWine = 1, ///< Run the executable under Wine on Linux agents
+};
+
+/** Parsed exception information from an Exception message. */
+struct ExceptionInfo
+{
+ std::string_view Message;
+ std::string_view Description;
+};
+
+/** Parsed blob read request from a ReadBlob message. */
+struct BlobRequest
+{
+ std::string_view Locator;
+ size_t Offset = 0;
+ size_t Length = 0;
+};
+
+/** Channel for sending and receiving agent messages over a ComputeChannel.
+ *
+ * Implements the Horde agent message protocol, matching the UE
+ * FAgentMessageChannel serialization format exactly. Messages are framed as
+ * [type (1B)][payload length (4B)][payload]. Strings use length-prefixed UTF-8;
+ * integers use variable-length encoding.
+ *
+ * The protocol has two directions:
+ * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob
+ * - Responses (remote -> initiator): ReadResponse returns the type, then call the
+ * appropriate Read* method to parse the payload.
+ */
+class AgentMessageChannel
+{
+public:
+ explicit AgentMessageChannel(Ref<ComputeChannel> Channel);
+ ~AgentMessageChannel();
+
+ AgentMessageChannel(const AgentMessageChannel&) = delete;
+ AgentMessageChannel& operator=(const AgentMessageChannel&) = delete;
+
+ // --- Requests (Initiator -> Remote) ---
+
+ /** Close the channel. */
+ void Close();
+
+ /** Send a keepalive ping. */
+ void Ping();
+
+ /** Fork communication to a new channel with the given ID and buffer size. */
+ void Fork(int ChannelId, int BufferSize);
+
+ /** Send an attach request (used during channel setup handshake). */
+ void Attach();
+
+ /** Request the remote agent to write files from the given bundle locator. */
+ void UploadFiles(const char* Path, const char* Locator);
+
+ /** Execute a process on the remote machine. */
+ void Execute(const char* Exe,
+ const char* const* Args,
+ size_t NumArgs,
+ const char* WorkingDir,
+ const char* const* EnvVars,
+ size_t NumEnvVars,
+ ExecuteProcessFlags Flags = ExecuteProcessFlags::None);
+
+ /** Send blob data in response to a ReadBlob request. */
+ void Blob(const uint8_t* Data, size_t Length);
+
+ // --- Responses (Remote -> Initiator) ---
+
+ /** Read the next response message. Returns the message type, or None on timeout.
+ * After this returns, use GetResponseData()/GetResponseSize() or the typed
+ * Read* methods to access the payload. */
+ AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr);
+
+ const void* GetResponseData() const { return m_ResponseData; }
+ size_t GetResponseSize() const { return m_ResponseLength; }
+
+ /** Parse an Exception response payload. */
+ void ReadException(ExceptionInfo& Ex);
+
+ /** Parse an ExecuteResult response payload. Returns the exit code. */
+ int ReadExecuteResult();
+
+ /** Parse a ReadBlob response payload into a BlobRequest. */
+ void ReadBlobRequest(BlobRequest& Req);
+
+private:
+ static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)]
+
+ Ref<ComputeChannel> m_Channel;
+
+ uint8_t* m_RequestData = nullptr;
+ size_t m_RequestSize = 0;
+ size_t m_MaxRequestSize = 0;
+
+ AgentMessageType m_ResponseType = AgentMessageType::None;
+ const uint8_t* m_ResponseData = nullptr;
+ size_t m_ResponseLength = 0;
+
+ void CreateMessage(AgentMessageType Type, size_t MaxLength);
+ void FlushMessage();
+
+ void WriteInt32(int Value);
+ static int ReadInt32(const uint8_t** Pos);
+
+ void WriteFixedLengthBytes(const uint8_t* Data, size_t Length);
+ static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length);
+
+ static size_t MeasureUnsignedVarInt(size_t Value);
+ void WriteUnsignedVarInt(size_t Value);
+ static size_t ReadUnsignedVarInt(const uint8_t** Pos);
+
+ size_t MeasureString(const char* Text) const;
+ void WriteString(const char* Text);
+ void WriteString(std::string_view Text);
+ static std::string_view ReadString(const uint8_t** Pos);
+
+ void WriteOptionalString(const char* Text);
+};
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordebundle.cpp b/src/zenhorde/hordebundle.cpp
new file mode 100644
index 000000000..d3974bc28
--- /dev/null
+++ b/src/zenhorde/hordebundle.cpp
@@ -0,0 +1,619 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
#include "hordebundle.h"

#include <zencore/basicfile.h>
#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
#include <zencore/intmath.h>
#include <zencore/iohash.h>
#include <zencore/logging.h>
#include <zencore/process.h>
#include <zencore/trace.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstring>
+
+namespace zen::horde {
+
+static LoggerRef
+Log()
+{
+ static auto s_Logger = zen::logging::Get("horde.bundle");
+ return s_Logger;
+}
+
// Packet framing constants shared by the uncompressed and encoded layouts.
static constexpr uint8_t PacketSignature[3] = {'U', 'B', 'N'};
static constexpr uint8_t PacketVersion = 5;
// Base index used when an import refers to an export inside the current packet.
static constexpr int32_t CurrentPacketBaseIdx = -2;
// Bias added to the base index before varint-encoding an import entry
// (keeps the encoded value non-negative for negative base indices).
static constexpr int ImportBias = 3;
static constexpr uint32_t ChunkSize = 64 * 1024;           // 64KB fixed chunks
static constexpr uint32_t LargeFileThreshold = 128 * 1024; // 128KB; files above this are split into chunks

// BlobType: 20 bytes each = FGuid (16 bytes, 4x uint32 LE) + Version (int32 LE)
// Values from UE SDK: GUIDs stored as 4 uint32 LE values.

// ChunkLeaf v1: {0xB27AFB68, 0x4A4B9E20, 0x8A78D8A4, 0x39D49840}
static constexpr uint8_t BlobType_ChunkLeafV1[20] = {0x68, 0xFB, 0x7A, 0xB2, 0x20, 0x9E, 0x4B, 0x4A, 0xA4, 0xD8,
                                                     0x78, 0x8A, 0x40, 0x98, 0xD4, 0x39, 0x01, 0x00, 0x00, 0x00}; // version 1

// ChunkInterior v2: {0xF4DEDDBC, 0x4C7A70CB, 0x11F04783, 0xB9CDCCAF}
static constexpr uint8_t BlobType_ChunkInteriorV2[20] = {0xBC, 0xDD, 0xDE, 0xF4, 0xCB, 0x70, 0x7A, 0x4C, 0x83, 0x47,
                                                         0xF0, 0x11, 0xAF, 0xCC, 0xCD, 0xB9, 0x02, 0x00, 0x00, 0x00}; // version 2

// Directory v1: {0x0714EC11, 0x4D07291A, 0x8AE77F86, 0x799980D6}
static constexpr uint8_t BlobType_DirectoryV1[20] = {0x11, 0xEC, 0x14, 0x07, 0x1A, 0x29, 0x07, 0x4D, 0x86, 0x7F,
                                                     0xE7, 0x8A, 0xD6, 0x80, 0x99, 0x79, 0x01, 0x00, 0x00, 0x00}; // version 1

static constexpr size_t BlobTypeSize = 20;
+
+// ─── VarInt helpers (UE format) ─────────────────────────────────────────────
+
/** Returns the number of bytes needed to encode Value as an Unreal-style
 * unsigned VarInt (prefix byte with ByteCount-1 leading one bits, followed
 * by the value bytes).
 *
 * The count is floor(log2(Value))/7 + 1, clamped to 9: a 9-byte encoding
 * (0xFF prefix + 8 payload bytes) can carry any 64-bit value. The previous
 * implementation truncated Value to 32 bits before taking the logarithm,
 * which under-measured values >= 4 GiB (e.g. the file-length varint written
 * into directory entries) and caused WriteVarInt to drop high bytes. */
static size_t
MeasureVarInt(size_t Value)
{
    // floor(log2(Value)) computed on the full 64-bit value; 0 maps to 0
    // so that Value == 0 and Value == 1 both measure as a single byte.
    size_t BitIndex = 0;
    for (uint64_t Scratch = static_cast<uint64_t>(Value); Scratch > 1; Scratch >>= 1)
    {
        ++BitIndex;
    }
    const size_t ByteCount = (BitIndex / 7) + 1;
    // Clamp: WriteVarInt's prefix shift only supports counts in [1, 9].
    return ByteCount < 9 ? ByteCount : 9;
}
+
// Append Value to Buffer using the Unreal unsigned-VarInt encoding: the first
// byte carries (ByteCount-1) leading one bits plus the remaining high-order
// value bits, and the trailing bytes hold the value most-significant-first.
static void
WriteVarInt(std::vector<uint8_t>& Buffer, size_t Value)
{
    const size_t ByteCount = MeasureVarInt(Value);
    const size_t Offset = Buffer.size();
    Buffer.resize(Offset + ByteCount);

    uint8_t* Output = Buffer.data() + Offset;
    // Fill trailing bytes back-to-front with the low-order bytes of Value.
    for (size_t i = 1; i < ByteCount; ++i)
    {
        Output[ByteCount - i] = static_cast<uint8_t>(Value);
        Value >>= 8;
    }
    // Prefix byte: (ByteCount-1) leading ones ORed with the leftover high bits.
    // Only valid for ByteCount in [1, 9]; a larger count would make the shift
    // amount negative (UB), so MeasureVarInt must never return more than 9.
    Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
}
+
+// ─── Binary helpers ─────────────────────────────────────────────────────────
+
/** Append the 4 raw bytes of Value (host byte order) to Buffer. */
static void
WriteLE32(std::vector<uint8_t>& Buffer, int32_t Value)
{
    const size_t Offset = Buffer.size();
    Buffer.resize(Offset + sizeof(Value));
    memcpy(Buffer.data() + Offset, &Value, sizeof(Value));
}
+
/** Append a single byte to Buffer. */
static void
WriteByte(std::vector<uint8_t>& Buffer, uint8_t Value)
{
    Buffer.insert(Buffer.end(), 1, Value);
}
+
/** Append Size raw bytes starting at Data to Buffer. */
static void
WriteBytes(std::vector<uint8_t>& Buffer, const void* Data, size_t Size)
{
    const uint8_t* First = static_cast<const uint8_t*>(Data);
    Buffer.insert(Buffer.end(), First, First + Size);
}
+
+static void
+WriteString(std::vector<uint8_t>& Buffer, std::string_view Str)
+{
+ WriteVarInt(Buffer, Str.size());
+ WriteBytes(Buffer, Str.data(), Str.size());
+}
+
/** Pad Buffer with zero bytes until its size is a multiple of 4. */
static void
AlignTo4(std::vector<uint8_t>& Buffer)
{
    const size_t Misalignment = Buffer.size() & 3;
    if (Misalignment != 0)
    {
        Buffer.insert(Buffer.end(), 4 - Misalignment, uint8_t{0});
    }
}
+
/** Overwrite 4 bytes at Offset in Buffer with the raw bytes of Value
 * (host byte order). Offset must leave room for all 4 bytes. */
static void
PatchLE32(std::vector<uint8_t>& Buffer, size_t Offset, int32_t Value)
{
    uint8_t Raw[sizeof(Value)];
    memcpy(Raw, &Value, sizeof(Value));
    std::copy(Raw, Raw + sizeof(Raw), Buffer.begin() + static_cast<std::ptrdiff_t>(Offset));
}
+
+// ─── Packet builder ─────────────────────────────────────────────────────────
+
+// Builds a single uncompressed Horde V2 packet. Layout:
+// [Signature(3) + Version(1) + PacketLength(4)] 8 bytes (header)
+// [TypeTableOffset(4) + ImportTableOffset(4) + ExportTableOffset(4)] 12 bytes
+// [Export data...]
+// [Type table: count(4) + count * 20 bytes]
+// [Import table: count(4) + (count+1) offset entries(4 each) + import data]
+// [Export table: count(4) + (count+1) offset entries(4 each)]
+//
+// ALL offsets are absolute from byte 0 of the full packet (including the 8-byte header).
+// PacketLength in the header = total packet size including the 8-byte header.
+
// Accumulates exports, a type table, and an import table into a single
// uncompressed V2 packet. Usage: BeginExport / WritePayload / CompleteExport
// per export, then Finish() exactly once (Finish moves Data out, leaving the
// builder unusable afterwards).
struct PacketBuilder
{
    std::vector<uint8_t> Data;
    std::vector<int32_t> ExportOffsets; // Absolute byte offset of each export from byte 0

    // Type table: unique 20-byte BlobType entries.
    // Pointers must remain valid for the builder's lifetime (the static
    // BlobType_* arrays above satisfy this).
    std::vector<const uint8_t*> Types;

    // Import table entries: (baseIdx, fragment)
    struct ImportEntry
    {
        int32_t BaseIdx;
        std::string Fragment;
    };
    std::vector<ImportEntry> Imports;

    // Current export's start offset (absolute from byte 0)
    size_t CurrentExportStart = 0;

    PacketBuilder()
    {
        // Reserve packet header (8 bytes) + table offsets (12 bytes) = 20 bytes
        Data.resize(20, 0);

        // Write signature
        Data[0] = PacketSignature[0];
        Data[1] = PacketSignature[1];
        Data[2] = PacketSignature[2];
        Data[3] = PacketVersion;
        // PacketLength, TypeTableOffset, ImportTableOffset, ExportTableOffset
        // will be patched in Finish()
    }

    // Deduplicating insert: returns the index of an existing identical
    // 20-byte entry, or appends BlobType and returns its new index.
    int AddType(const uint8_t* BlobType)
    {
        for (size_t i = 0; i < Types.size(); ++i)
        {
            if (memcmp(Types[i], BlobType, BlobTypeSize) == 0)
            {
                return static_cast<int>(i);
            }
        }
        Types.push_back(BlobType);
        return static_cast<int>(Types.size() - 1);
    }

    // Appends an import entry; returns its table index (imports are NOT
    // deduplicated, unlike types).
    int AddImport(int32_t BaseIdx, std::string Fragment)
    {
        Imports.push_back({BaseIdx, std::move(Fragment)});
        return static_cast<int>(Imports.size() - 1);
    }

    void BeginExport()
    {
        AlignTo4(Data);
        CurrentExportStart = Data.size();
        // Reserve space for payload length
        WriteLE32(Data, 0);
    }

    // Write raw payload data into the current export
    void WritePayload(const void* Payload, size_t Size) { WriteBytes(Data, Payload, Size); }

    // Complete the current export: patches payload length, writes type+imports metadata
    int CompleteExport(const uint8_t* BlobType, const std::vector<int>& ImportIndices)
    {
        const int ExportIndex = static_cast<int>(ExportOffsets.size());

        // Patch payload length (does not include the 4-byte length field itself).
        // Computed BEFORE the type/import metadata below is appended, so the
        // length covers the payload bytes only.
        const size_t PayloadStart = CurrentExportStart + 4;
        const int32_t PayloadLen = static_cast<int32_t>(Data.size() - PayloadStart);
        PatchLE32(Data, CurrentExportStart, PayloadLen);

        // Write type index (varint)
        const int TypeIdx = AddType(BlobType);
        WriteVarInt(Data, static_cast<size_t>(TypeIdx));

        // Write import count + indices
        WriteVarInt(Data, ImportIndices.size());
        for (int Idx : ImportIndices)
        {
            WriteVarInt(Data, static_cast<size_t>(Idx));
        }

        // Record export offset (absolute from byte 0)
        ExportOffsets.push_back(static_cast<int32_t>(CurrentExportStart));

        return ExportIndex;
    }

    // Finalize the packet: write type/import/export tables, patch header.
    std::vector<uint8_t> Finish()
    {
        AlignTo4(Data);

        // ── Type table: count(int32) + count * BlobTypeSize bytes ──
        const int32_t TypeTableOffset = static_cast<int32_t>(Data.size());
        WriteLE32(Data, static_cast<int32_t>(Types.size()));
        for (const uint8_t* TypeEntry : Types)
        {
            WriteBytes(Data, TypeEntry, BlobTypeSize);
        }

        // ── Import table: count(int32) + (count+1) offsets(int32 each) + import data ──
        const int32_t ImportTableOffset = static_cast<int32_t>(Data.size());
        const int32_t ImportCount = static_cast<int32_t>(Imports.size());
        WriteLE32(Data, ImportCount);

        // Reserve space for (count+1) offset entries — will be patched below
        const size_t ImportOffsetsStart = Data.size();
        for (int32_t i = 0; i <= ImportCount; ++i)
        {
            WriteLE32(Data, 0); // placeholder
        }

        // Write import data and record offsets
        for (int32_t i = 0; i < ImportCount; ++i)
        {
            // Record absolute offset of this import's data
            PatchLE32(Data, ImportOffsetsStart + static_cast<size_t>(i) * 4, static_cast<int32_t>(Data.size()));

            ImportEntry& Imp = Imports[static_cast<size_t>(i)];
            // BaseIdx encoded as unsigned VarInt with bias: VarInt(BaseIdx + ImportBias)
            const size_t EncodedBaseIdx = static_cast<size_t>(static_cast<int64_t>(Imp.BaseIdx) + ImportBias);
            WriteVarInt(Data, EncodedBaseIdx);
            // Fragment: raw UTF-8 bytes, NO length prefix (length determined by offset table)
            WriteBytes(Data, Imp.Fragment.data(), Imp.Fragment.size());
        }

        // Sentinel offset (points past the last import's data)
        PatchLE32(Data, ImportOffsetsStart + static_cast<size_t>(ImportCount) * 4, static_cast<int32_t>(Data.size()));

        // ── Export table: count(int32) + (count+1) offsets(int32 each) ──
        const int32_t ExportTableOffset = static_cast<int32_t>(Data.size());
        const int32_t ExportCount = static_cast<int32_t>(ExportOffsets.size());
        WriteLE32(Data, ExportCount);

        for (int32_t Off : ExportOffsets)
        {
            WriteLE32(Data, Off);
        }
        // Sentinel: points to the start of the type table (end of export data region)
        WriteLE32(Data, TypeTableOffset);

        // ── Patch header ──
        // PacketLength = total packet size including the 8-byte header
        const int32_t PacketLength = static_cast<int32_t>(Data.size());
        PatchLE32(Data, 4, PacketLength);
        PatchLE32(Data, 8, TypeTableOffset);
        PatchLE32(Data, 12, ImportTableOffset);
        PatchLE32(Data, 16, ExportTableOffset);

        // Moves the member out: the builder holds no packet data afterwards.
        return std::move(Data);
    }
};
+
+// ─── Encoded packet wrapper ─────────────────────────────────────────────────
+
+// Wraps an uncompressed packet with the encoded header:
+// [Signature(3) + Version(1) + HeaderLength(4)] 8 bytes
+// [DecompressedLength(4)] 4 bytes
+// [CompressionFormat(1): 0=None] 1 byte
+// [PacketData...]
+//
+// HeaderLength = total encoded packet size INCLUDING the 8-byte outer header.
+
+static std::vector<uint8_t>
+EncodePacket(std::vector<uint8_t> UncompressedPacket)
+{
+ const int32_t DecompressedLen = static_cast<int32_t>(UncompressedPacket.size());
+ // HeaderLength includes the 8-byte outer signature header itself
+ const int32_t HeaderLength = 8 + 4 + 1 + DecompressedLen;
+
+ std::vector<uint8_t> Encoded;
+ Encoded.reserve(static_cast<size_t>(HeaderLength));
+
+ // Outer signature: 'U','B','N', version=5, HeaderLength (LE int32)
+ WriteByte(Encoded, PacketSignature[0]); // 'U'
+ WriteByte(Encoded, PacketSignature[1]); // 'B'
+ WriteByte(Encoded, PacketSignature[2]); // 'N'
+ WriteByte(Encoded, PacketVersion); // 5
+ WriteLE32(Encoded, HeaderLength);
+
+ // Decompressed length + compression format
+ WriteLE32(Encoded, DecompressedLen);
+ WriteByte(Encoded, 0); // CompressionFormat::None
+
+ // Packet data
+ WriteBytes(Encoded, UncompressedPacket.data(), UncompressedPacket.size());
+
+ return Encoded;
+}
+
+// ─── Bundle blob name generation ────────────────────────────────────────────
+
+static std::string
+GenerateBlobName()
+{
+ static std::atomic<uint32_t> s_Counter{0};
+
+ const int Pid = GetCurrentProcessId();
+
+ auto Now = std::chrono::steady_clock::now().time_since_epoch();
+ auto Ms = std::chrono::duration_cast<std::chrono::milliseconds>(Now).count();
+
+ ExtendableStringBuilder<64> Name;
+ Name << Pid << "_" << Ms << "_" << s_Counter.fetch_add(1);
+ return std::string(Name.ToView());
+}
+
+// ─── File info for bundling ─────────────────────────────────────────────────
+
+struct FileInfo
+{
+ std::filesystem::path Path;
+ std::string Name; // Filename only (for directory entry)
+ uint64_t FileSize;
+ IoHash ContentHash; // IoHash of file content
+ BLAKE3 StreamHash; // Full BLAKE3 for stream hash
+ int DirectoryExportImportIndex; // Import index referencing this file's root export
+ IoHash RootExportHash; // IoHash of the root export for this file
+};
+
+// ─── CreateBundle implementation ────────────────────────────────────────────
+
// Create a single-packet V2 bundle: one chunk-leaf export per small file (or
// per 64KB chunk plus an interior node for large files), a directory export
// referencing them all, then write the encoded packet to <OutputDir>/<name>.blob
// and the locator string to a sibling .ref file.
bool
BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult)
{
    ZEN_TRACE_CPU("BundleCreator::CreateBundle");

    std::error_code Ec;

    // Collect files that exist (missing optional files are silently skipped;
    // missing required files abort the whole bundle)
    std::vector<FileInfo> ValidFiles;
    for (const BundleFile& F : Files)
    {
        if (!std::filesystem::exists(F.Path, Ec))
        {
            if (F.Optional)
            {
                continue;
            }
            ZEN_ERROR("required bundle file does not exist: {}", F.Path.string());
            return false;
        }
        FileInfo Info;
        Info.Path = F.Path;
        Info.Name = F.Path.filename().string();
        Info.FileSize = std::filesystem::file_size(F.Path, Ec);
        if (Ec)
        {
            ZEN_ERROR("failed to get file size: {}", F.Path.string());
            return false;
        }
        ValidFiles.push_back(std::move(Info));
    }

    if (ValidFiles.empty())
    {
        ZEN_ERROR("no valid files to bundle");
        return false;
    }

    std::filesystem::create_directories(OutputDir, Ec);
    if (Ec)
    {
        ZEN_ERROR("failed to create output directory: {}", OutputDir.string());
        return false;
    }

    const std::string BlobName = GenerateBlobName();
    PacketBuilder Packet;

    // Process each file: create chunk exports
    for (FileInfo& Info : ValidFiles)
    {
        BasicFile File;
        File.Open(Info.Path, BasicFile::Mode::kRead, Ec);
        if (Ec)
        {
            ZEN_ERROR("failed to open file: {}", Info.Path.string());
            return false;
        }

        // Compute stream hash (full BLAKE3) and content hash (IoHash) while reading
        BLAKE3Stream StreamHasher;
        IoHashStream ContentHasher;

        if (Info.FileSize <= LargeFileThreshold)
        {
            // Small file: single chunk leaf export
            IoBuffer Content = File.ReadAll();
            const auto* Data = static_cast<const uint8_t*>(Content.GetData());
            const size_t Size = Content.GetSize();

            StreamHasher.Append(Data, Size);
            ContentHasher.Append(Data, Size);

            Packet.BeginExport();
            Packet.WritePayload(Data, Size);

            const IoHash ChunkHash = IoHash::HashBuffer(Data, Size);
            const int ExportIndex = Packet.CompleteExport(BlobType_ChunkLeafV1, {});
            Info.RootExportHash = ChunkHash;
            Info.ContentHash = ContentHasher.GetHash();
            Info.StreamHash = StreamHasher.GetHash();

            // Add import for this file's root export (references export within same packet)
            ExtendableStringBuilder<32> Fragment;
            Fragment << "exp=" << ExportIndex;
            Info.DirectoryExportImportIndex = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView()));
        }
        else
        {
            // Large file: split into fixed 64KB chunks, then create interior node
            std::vector<int> ChunkExportIndices;
            std::vector<IoHash> ChunkHashes;

            uint64_t Remaining = Info.FileSize;
            uint64_t Offset = 0;

            while (Remaining > 0)
            {
                const uint64_t ReadSize = std::min(static_cast<uint64_t>(ChunkSize), Remaining);
                IoBuffer Chunk = File.ReadRange(Offset, ReadSize);
                const auto* Data = static_cast<const uint8_t*>(Chunk.GetData());
                const size_t Size = Chunk.GetSize();

                StreamHasher.Append(Data, Size);
                ContentHasher.Append(Data, Size);

                Packet.BeginExport();
                Packet.WritePayload(Data, Size);

                const IoHash ChunkHash = IoHash::HashBuffer(Data, Size);
                const int ExpIdx = Packet.CompleteExport(BlobType_ChunkLeafV1, {});

                ChunkExportIndices.push_back(ExpIdx);
                ChunkHashes.push_back(ChunkHash);

                Offset += ReadSize;
                Remaining -= ReadSize;
            }

            Info.ContentHash = ContentHasher.GetHash();
            Info.StreamHash = StreamHasher.GetHash();

            // Create interior node referencing all chunk leaves
            // Interior payload: for each child: [IoHash(20)][node_type=1(1)] + imports
            std::vector<int> InteriorImports;
            for (size_t i = 0; i < ChunkExportIndices.size(); ++i)
            {
                ExtendableStringBuilder<32> Fragment;
                Fragment << "exp=" << ChunkExportIndices[i];
                const int ImportIdx = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView()));
                InteriorImports.push_back(ImportIdx);
            }

            Packet.BeginExport();

            // Write interior payload: [hash(20)][type(1)] per child
            for (size_t i = 0; i < ChunkHashes.size(); ++i)
            {
                Packet.WritePayload(ChunkHashes[i].Hash, sizeof(IoHash));
                const uint8_t NodeType = 1; // ChunkNode type
                Packet.WritePayload(&NodeType, 1);
            }

            // Hash the interior payload to get the interior node hash. The payload
            // begins 4 bytes past the export start (after the length field) and
            // CompleteExport has not yet appended the type/import metadata, so the
            // hash covers exactly the payload bytes written above.
            const IoHash InteriorHash = IoHash::HashBuffer(Packet.Data.data() + (Packet.CurrentExportStart + 4),
                                                           Packet.Data.size() - (Packet.CurrentExportStart + 4));

            const int InteriorExportIndex = Packet.CompleteExport(BlobType_ChunkInteriorV2, InteriorImports);

            Info.RootExportHash = InteriorHash;

            // Add import for directory to reference this interior node
            ExtendableStringBuilder<32> Fragment;
            Fragment << "exp=" << InteriorExportIndex;
            Info.DirectoryExportImportIndex = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView()));
        }
    }

    // Create directory node export
    // Payload: [flags(varint=0)] [file_count(varint)] [file_entries...] [dir_count(varint=0)]
    // FileEntry: [import(varint)] [IoHash(20)] [name(string)] [flags(varint)] [length(varint)] [IoHash_stream(20)]

    Packet.BeginExport();

    // Build directory payload into a temporary buffer, then write it
    std::vector<uint8_t> DirPayload;
    WriteVarInt(DirPayload, 0);                 // flags
    WriteVarInt(DirPayload, ValidFiles.size()); // file_count

    std::vector<int> DirImports;
    for (size_t i = 0; i < ValidFiles.size(); ++i)
    {
        FileInfo& Info = ValidFiles[i];
        DirImports.push_back(Info.DirectoryExportImportIndex);

        // IoHash of target (20 bytes) — import is consumed sequentially from the
        // export's import list by ReadBlobRef, not encoded in the payload
        WriteBytes(DirPayload, Info.RootExportHash.Hash, sizeof(IoHash));
        // name (string)
        WriteString(DirPayload, Info.Name);
        // flags (varint): 1 = Executable
        // NOTE(review): every entry is marked executable regardless of the source
        // file's permissions — confirm this is intended for all bundled files.
        WriteVarInt(DirPayload, 1);
        // length (varint)
        WriteVarInt(DirPayload, static_cast<size_t>(Info.FileSize));
        // stream hash: IoHash from full BLAKE3, truncated to 20 bytes
        const IoHash StreamIoHash = IoHash::FromBLAKE3(Info.StreamHash);
        WriteBytes(DirPayload, StreamIoHash.Hash, sizeof(IoHash));
    }

    WriteVarInt(DirPayload, 0); // dir_count

    Packet.WritePayload(DirPayload.data(), DirPayload.size());
    const int DirExportIndex = Packet.CompleteExport(BlobType_DirectoryV1, DirImports);

    // Finalize packet and encode
    std::vector<uint8_t> UncompressedPacket = Packet.Finish();
    std::vector<uint8_t> EncodedPacket = EncodePacket(std::move(UncompressedPacket));

    // Write .blob file
    const std::filesystem::path BlobFilePath = OutputDir / (BlobName + ".blob");
    {
        BasicFile BlobFile(BlobFilePath, BasicFile::Mode::kTruncate, Ec);
        if (Ec)
        {
            ZEN_ERROR("failed to create blob file: {}", BlobFilePath.string());
            return false;
        }
        BlobFile.Write(EncodedPacket.data(), EncodedPacket.size(), 0);
    }

    // Build locator: <blob_name>#pkt=0,<encoded_len>&exp=<dir_export_index>
    ExtendableStringBuilder<256> Locator;
    Locator << BlobName << "#pkt=0," << uint64_t(EncodedPacket.size()) << "&exp=" << DirExportIndex;
    const std::string LocatorStr(Locator.ToView());

    // Write .ref file (use first file's name as the ref base)
    const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + ".Bundle.ref");
    {
        BasicFile RefFile(RefFilePath, BasicFile::Mode::kTruncate, Ec);
        if (Ec)
        {
            ZEN_ERROR("failed to create ref file: {}", RefFilePath.string());
            return false;
        }
        RefFile.Write(LocatorStr.data(), LocatorStr.size(), 0);
    }

    OutResult.Locator = LocatorStr;
    OutResult.BundleDir = OutputDir;

    ZEN_INFO("created V2 bundle: blob={}.blob locator={} files={}", BlobName, LocatorStr, ValidFiles.size());
    return true;
}
+
+bool
+BundleCreator::ReadLocator(const std::filesystem::path& RefFile, std::string& OutLocator)
+{
+ BasicFile File;
+ std::error_code Ec;
+ File.Open(RefFile, BasicFile::Mode::kRead, Ec);
+ if (Ec)
+ {
+ return false;
+ }
+
+ IoBuffer Content = File.ReadAll();
+ OutLocator.assign(static_cast<const char*>(Content.GetData()), Content.GetSize());
+
+ // Strip trailing whitespace/newlines
+ while (!OutLocator.empty() && (OutLocator.back() == '\n' || OutLocator.back() == '\r' || OutLocator.back() == '\0'))
+ {
+ OutLocator.pop_back();
+ }
+
+ return !OutLocator.empty();
+}
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordebundle.h b/src/zenhorde/hordebundle.h
new file mode 100644
index 000000000..052f60435
--- /dev/null
+++ b/src/zenhorde/hordebundle.h
@@ -0,0 +1,49 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <filesystem>
+#include <string>
+#include <vector>
+
+namespace zen::horde {
+
/** Describes a file to include in a Horde bundle. */
struct BundleFile
{
    std::filesystem::path Path;    ///< Local file path
    bool Optional = false;         ///< If true, skip without error if missing (default: required)
};
+
/** Result of a successful bundle creation. */
struct BundleResult
{
    // Locator format: <blob_name>#pkt=0,<encoded_len>&exp=<directory_export_index>
    std::string Locator;              ///< Root directory locator for WriteFiles
    std::filesystem::path BundleDir;  ///< Directory containing .blob files
};
+
/** Creates Horde V2 bundles from local files for upload to remote agents.
 *
 * Produces a proper Horde storage V2 bundle containing:
 * - Chunk leaf exports for file data (split into 64KB chunks for large files)
 * - Optional interior chunk nodes referencing leaf chunks
 * - A directory node listing all bundled files with metadata
 *
 * The bundle is written as a single .blob file with a corresponding .ref file
 * containing the locator string. The locator format is:
 *   <blob_name>#pkt=0,<encoded_len>&exp=<directory_export_index>
 *
 * Stateless: both operations are static and touch only the paths passed in.
 */
struct BundleCreator
{
    /** Create a V2 bundle from one or more input files.
     * @param Files Files to include in the bundle (entries marked Optional
     *              are skipped without error when missing).
     * @param OutputDir Directory where .blob and .ref files will be written
     *                  (created if necessary).
     * @param OutResult Receives the locator and output directory on success.
     * @return True on success. */
    static bool CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult);

    /** Read a locator string from a .ref file. Strips trailing whitespace/newlines. */
    static bool ReadLocator(const std::filesystem::path& RefFile, std::string& OutLocator);
};
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp
new file mode 100644
index 000000000..fb981f0ba
--- /dev/null
+++ b/src/zenhorde/hordeclient.cpp
@@ -0,0 +1,382 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/fmtutils.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/memoryview.h>
+#include <zencore/trace.h>
+#include <zenhorde/hordeclient.h>
+#include <zenhttp/httpclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::horde {
+
/** Construct a client bound to the given configuration. No network activity
 * occurs until Initialize() is called. */
HordeClient::HordeClient(const HordeConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("horde.client"))
{
}

HordeClient::~HordeClient() = default;
+
+bool
+HordeClient::Initialize()
+{
+ ZEN_TRACE_CPU("HordeClient::Initialize");
+
+ HttpClientSettings Settings;
+ Settings.LogCategory = "horde.http";
+ Settings.ConnectTimeout = std::chrono::milliseconds{10000};
+ Settings.Timeout = std::chrono::milliseconds{60000};
+ Settings.RetryCount = 1;
+ Settings.ExpectedErrorCodes = {HttpResponseCode::ServiceUnavailable, HttpResponseCode::TooManyRequests};
+
+ if (!m_Config.AuthToken.empty())
+ {
+ Settings.AccessTokenProvider = [token = m_Config.AuthToken]() -> HttpClientAccessToken {
+ HttpClientAccessToken Token;
+ Token.Value = token;
+ Token.ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours{24};
+ return Token;
+ };
+ }
+
+ m_Http = std::make_unique<zen::HttpClient>(m_Config.ServerUrl, Settings);
+
+ if (!m_Config.AuthToken.empty())
+ {
+ if (!m_Http->Authenticate())
+ {
+ ZEN_WARN("failed to authenticate with Horde server");
+ return false;
+ }
+ }
+
+ return true;
+}
+
+std::string
+HordeClient::BuildRequestBody() const
+{
+ json11::Json::object Requirements;
+
+ if (m_Config.Mode == ConnectionMode::Direct && !m_Config.Pool.empty())
+ {
+ Requirements["pool"] = m_Config.Pool;
+ }
+
+ std::string Condition;
+#if ZEN_PLATFORM_WINDOWS
+ ExtendableStringBuilder<256> CondBuf;
+ CondBuf << "(OSFamily == 'Windows' || WineEnabled == '" << (m_Config.AllowWine ? "true" : "false") << "')";
+ Condition = std::string(CondBuf);
+#elif ZEN_PLATFORM_MAC
+ Condition = "OSFamily == 'MacOS'";
+#else
+ Condition = "OSFamily == 'Linux'";
+#endif
+
+ if (!m_Config.Condition.empty())
+ {
+ Condition += " ";
+ Condition += m_Config.Condition;
+ }
+
+ Requirements["condition"] = Condition;
+ Requirements["exclusive"] = true;
+
+ json11::Json::object Connection;
+ Connection["modePreference"] = ToString(m_Config.Mode);
+
+ if (m_Config.EncryptionMode != Encryption::None)
+ {
+ Connection["encryption"] = ToString(m_Config.EncryptionMode);
+ }
+
+ // Request configured zen service port to be forwarded. The Horde agent will map this
+ // to a local port on the provisioned machine and report it back in the response.
+ json11::Json::object PortsObj;
+ PortsObj["ZenPort"] = json11::Json(m_Config.ZenServicePort);
+ Connection["ports"] = PortsObj;
+
+ json11::Json::object Root;
+ Root["requirements"] = Requirements;
+ Root["connection"] = Connection;
+
+ return json11::Json(Root).dump();
+}
+
+bool
+HordeClient::ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster)
+{
+ ZEN_TRACE_CPU("HordeClient::ResolveCluster");
+
+ const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{RequestBody.data(), RequestBody.size()}, ZenContentType::kJSON);
+
+ const HttpClient::Response Response = m_Http->Post("api/v2/compute/_cluster", Payload);
+
+ if (Response.Error)
+ {
+ ZEN_WARN("cluster resolution failed: {}", Response.Error->ErrorMessage);
+ return false;
+ }
+
+ const int StatusCode = static_cast<int>(Response.StatusCode);
+
+ if (StatusCode == 503 || StatusCode == 429)
+ {
+ ZEN_DEBUG("cluster resolution returned HTTP/{}: no resources", StatusCode);
+ return false;
+ }
+
+ if (StatusCode == 401)
+ {
+ ZEN_WARN("cluster resolution returned HTTP/401: token expired");
+ return false;
+ }
+
+ if (!Response.IsSuccess())
+ {
+ ZEN_WARN("cluster resolution failed with HTTP/{}", StatusCode);
+ return false;
+ }
+
+ const std::string Body(Response.AsText());
+ std::string Err;
+ const json11::Json Json = json11::Json::parse(Body, Err);
+
+ if (!Err.empty())
+ {
+ ZEN_WARN("invalid JSON response for cluster resolution: {}", Err);
+ return false;
+ }
+
+ const json11::Json ClusterIdVal = Json["clusterId"];
+ if (!ClusterIdVal.is_string() || ClusterIdVal.string_value().empty())
+ {
+ ZEN_WARN("missing 'clusterId' in cluster resolution response");
+ return false;
+ }
+
+ OutCluster.ClusterId = ClusterIdVal.string_value();
+ return true;
+}
+
+bool
+HordeClient::ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize)
+{
+ if (Hex.size() != OutSize * 2)
+ {
+ return false;
+ }
+
+ for (size_t i = 0; i < OutSize; ++i)
+ {
+ auto HexToByte = [](char c) -> int {
+ if (c >= '0' && c <= '9')
+ return c - '0';
+ if (c >= 'a' && c <= 'f')
+ return c - 'a' + 10;
+ if (c >= 'A' && c <= 'F')
+ return c - 'A' + 10;
+ return -1;
+ };
+
+ const int Hi = HexToByte(Hex[i * 2]);
+ const int Lo = HexToByte(Hex[i * 2 + 1]);
+ if (Hi < 0 || Lo < 0)
+ {
+ return false;
+ }
+ Out[i] = static_cast<uint8_t>((Hi << 4) | Lo);
+ }
+
+ return true;
+}
+
// Request a compute machine from the given cluster and parse the assignment
// response (address, nonce, forwarded ports, encryption material, lease id,
// machine properties). Returns false on transport errors, no-capacity/auth
// status codes, or a malformed response; OutMachine is reset to an invalid
// state before any parsing happens.
bool
HordeClient::RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine)
{
    ZEN_TRACE_CPU("HordeClient::RequestMachine");

    ZEN_INFO("requesting machine from Horde with cluster '{}'", ClusterId.empty() ? "default" : ClusterId.c_str());

    ExtendableStringBuilder<128> ResourcePath;
    ResourcePath << "api/v2/compute/" << (ClusterId.empty() ? "default" : ClusterId.c_str());

    const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{RequestBody.data(), RequestBody.size()}, ZenContentType::kJSON);
    const HttpClient::Response Response = m_Http->Post(ResourcePath.ToView(), Payload);

    // Reset output to invalid state (port 0xFFFF acts as the "unset" sentinel)
    OutMachine = {};
    OutMachine.Port = 0xFFFF;

    if (Response.Error)
    {
        ZEN_WARN("machine request failed: {}", Response.Error->ErrorMessage);
        return false;
    }

    const int StatusCode = static_cast<int>(Response.StatusCode);

    if (StatusCode == 404 || StatusCode == 503 || StatusCode == 429)
    {
        // Expected when the pool has no free agents; log quietly.
        ZEN_DEBUG("machine request returned HTTP/{}: no resources", StatusCode);
        return false;
    }

    if (StatusCode == 401)
    {
        ZEN_WARN("machine request returned HTTP/401: token expired");
        return false;
    }

    if (!Response.IsSuccess())
    {
        ZEN_WARN("machine request failed with HTTP/{}", StatusCode);
        return false;
    }

    const std::string Body(Response.AsText());
    std::string Err;
    const json11::Json Json = json11::Json::parse(Body, Err);

    if (!Err.empty())
    {
        ZEN_WARN("invalid JSON response for machine request: {}", Err);
        return false;
    }

    // Required fields
    const json11::Json NonceVal = Json["nonce"];
    const json11::Json IpVal = Json["ip"];
    const json11::Json PortVal = Json["port"];

    if (!NonceVal.is_string() || !IpVal.is_string() || !PortVal.is_number())
    {
        ZEN_WARN("missing 'nonce', 'ip', or 'port' in machine response");
        return false;
    }

    OutMachine.Ip = IpVal.string_value();
    OutMachine.Port = static_cast<uint16_t>(PortVal.int_value());

    // The nonce arrives hex-encoded and must decode to exactly NonceSize bytes.
    if (!ParseHexBytes(NonceVal.string_value(), OutMachine.Nonce, NonceSize))
    {
        ZEN_WARN("invalid nonce hex string in machine response");
        return false;
    }

    // Optional: map of logical port name -> {port, agentPort} forwardings.
    if (const json11::Json PortsVal = Json["ports"]; PortsVal.is_object())
    {
        for (const auto& [Key, Val] : PortsVal.object_items())
        {
            PortInfo Info;
            if (Val["port"].is_number())
            {
                Info.Port = static_cast<uint16_t>(Val["port"].int_value());
            }
            if (Val["agentPort"].is_number())
            {
                Info.AgentPort = static_cast<uint16_t>(Val["agentPort"].int_value());
            }
            OutMachine.Ports[Key] = Info;
        }
    }

    // Optional: connection mode override; the address is only honored when
    // the mode string parsed successfully.
    if (const json11::Json ConnectionModeVal = Json["connectionMode"]; ConnectionModeVal.is_string())
    {
        if (FromString(OutMachine.Mode, ConnectionModeVal.string_value()))
        {
            if (const json11::Json ConnectionAddressVal = Json["connectionAddress"]; ConnectionAddressVal.is_string())
            {
                OutMachine.ConnectionAddress = ConnectionAddressVal.string_value();
            }
        }
    }

    // Properties are a flat string array of "Key=Value" pairs describing the machine.
    // We extract OS family and core counts for sizing decisions. If neither core count
    // is available, we fall back to 16 as a conservative default.
    uint16_t LogicalCores = 0;
    uint16_t PhysicalCores = 0;

    if (const json11::Json PropertiesVal = Json["properties"]; PropertiesVal.is_array())
    {
        for (const json11::Json& PropVal : PropertiesVal.array_items())
        {
            if (!PropVal.is_string())
            {
                continue;
            }

            const std::string Prop = PropVal.string_value();
            if (Prop.starts_with("OSFamily="))
            {
                if (Prop.substr(9) == "Windows")
                {
                    OutMachine.IsWindows = true;
                }
            }
            else if (Prop.starts_with("LogicalCores="))
            {
                LogicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 13));
            }
            else if (Prop.starts_with("PhysicalCores="))
            {
                PhysicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 14));
            }
        }
    }

    if (LogicalCores > 0)
    {
        OutMachine.LogicalCores = LogicalCores;
    }
    else if (PhysicalCores > 0)
    {
        // Assume 2-way SMT when only the physical count is reported.
        OutMachine.LogicalCores = PhysicalCores * 2;
    }
    else
    {
        OutMachine.LogicalCores = 16;
    }

    // Optional: encryption mode plus AES key material when AES is selected.
    // A missing/invalid key is logged but does not fail the request here.
    if (const json11::Json EncryptionVal = Json["encryption"]; EncryptionVal.is_string())
    {
        if (FromString(OutMachine.EncryptionMode, EncryptionVal.string_value()))
        {
            if (OutMachine.EncryptionMode == Encryption::AES)
            {
                const json11::Json KeyVal = Json["key"];
                if (KeyVal.is_string() && !KeyVal.string_value().empty())
                {
                    if (!ParseHexBytes(KeyVal.string_value(), OutMachine.Key, KeySize))
                    {
                        ZEN_WARN("invalid AES key in machine response");
                    }
                }
                else
                {
                    ZEN_WARN("AES encryption requested but no key provided");
                }
            }
        }
    }

    if (const json11::Json LeaseIdVal = Json["leaseId"]; LeaseIdVal.is_string())
    {
        OutMachine.LeaseId = LeaseIdVal.string_value();
    }

    ZEN_INFO("Horde machine assigned [{}:{}] cores={} lease={}",
             OutMachine.GetConnectionAddress(),
             OutMachine.GetConnectionPort(),
             OutMachine.LogicalCores,
             OutMachine.LeaseId);

    return true;
}
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp
new file mode 100644
index 000000000..0d032b5d5
--- /dev/null
+++ b/src/zenhorde/hordecomputebuffer.cpp
@@ -0,0 +1,454 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordecomputebuffer.h"
+
+#include <algorithm>
+#include <cassert>
+#include <chrono>
+#include <condition_variable>
+#include <cstring>
+
+namespace zen::horde {
+
+// Simplified ring buffer implementation for in-process use only.
+// Uses a single contiguous buffer with write/read cursors and
+// mutex+condvar for synchronization. This is simpler than the UE version
+// which uses lock-free atomics and shared memory, but sufficient for our
+// use case where we're the initiator side of the compute protocol.
+
+/** Shared state between one ComputeBuffer and its reader/writer endpoints.
+ *  Reference-counted (TRefCounted) so endpoints remain valid after the owning
+ *  ComputeBuffer is closed. All cursor/length fields are guarded by Mutex. */
+struct ComputeBuffer::Detail : TRefCounted<Detail>
+{
+    std::vector<uint8_t> Data; ///< Backing storage: NumChunks contiguous chunks of ChunkLength bytes each
+    size_t NumChunks = 0;
+    size_t ChunkLength = 0;
+
+    // Current write state (guarded by Mutex)
+    size_t WriteChunkIdx = 0;
+    size_t WriteOffset = 0;
+    bool WriteComplete = false; ///< Writer signalled end-of-stream; no more data will arrive
+
+    // Current read state (guarded by Mutex)
+    size_t ReadChunkIdx = 0;
+    size_t ReadOffset = 0;
+    bool Detached = false; ///< Reader gave up; unblocks both reader and writer waits
+
+    // Per-chunk written length (bytes valid in each chunk)
+    std::vector<size_t> ChunkWrittenLength;
+    std::vector<bool> ChunkFinished; // Writer moved to next chunk
+
+    std::mutex Mutex;
+    std::condition_variable ReadCV;  ///< Signaled when new data is written or stream completes
+    std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space
+
+    bool HasWriter = false;
+    bool HasReader = false;
+
+    // Address of the start of a given chunk inside the flat Data buffer.
+    uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; }
+    const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; }
+};
+
+// ComputeBuffer
+
+ComputeBuffer::ComputeBuffer()
+{
+}
+ComputeBuffer::~ComputeBuffer()
+{
+}
+
+/** Allocate the shared ring-buffer state.
+ *  @return false if the requested geometry is unusable. */
+bool
+ComputeBuffer::CreateNew(const Params& InParams)
+{
+    // Reject degenerate geometries up front: zero chunks/length would make the
+    // ring arithmetic (modulo NumChunks) undefined, and a single chunk
+    // deadlocks the writer in WaitToWrite once it wraps onto the chunk the
+    // reader still occupies.
+    if (InParams.NumChunks < 2 || InParams.ChunkLength == 0)
+    {
+        return false;
+    }
+
+    // Take ownership immediately so the allocation is not leaked if a
+    // vector resize below throws.
+    Ref<Detail> NewDetail(new Detail());
+    NewDetail->NumChunks = InParams.NumChunks;
+    NewDetail->ChunkLength = InParams.ChunkLength;
+    NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0);
+    NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0);
+    NewDetail->ChunkFinished.resize(InParams.NumChunks, false);
+
+    m_Detail = std::move(NewDetail);
+    return true;
+}
+
+void
+ComputeBuffer::Close()
+{
+    // Drop our reference; the shared state is released once all readers and
+    // writers have released theirs as well.
+    m_Detail = nullptr;
+}
+
+bool
+ComputeBuffer::IsValid() const
+{
+    return m_Detail ? true : false;
+}
+
+ComputeBufferReader
+ComputeBuffer::CreateReader()
+{
+    // Hand the reader its own reference to the shared state.
+    assert(m_Detail);
+    Ref<Detail> Shared = m_Detail;
+    Shared->HasReader = true;
+    return ComputeBufferReader(std::move(Shared));
+}
+
+ComputeBufferWriter
+ComputeBuffer::CreateWriter()
+{
+    // Hand the writer its own reference to the shared state.
+    assert(m_Detail);
+    Ref<Detail> Shared = m_Detail;
+    Shared->HasWriter = true;
+    return ComputeBufferWriter(std::move(Shared));
+}
+
+// ComputeBufferReader
+
+ComputeBufferReader::ComputeBufferReader() = default;
+ComputeBufferReader::~ComputeBufferReader() = default;
+
+ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default;
+ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default;
+ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default;
+ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default;
+
+ComputeBufferReader::ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
+{
+}
+
+void
+ComputeBufferReader::Close()
+{
+    m_Detail = nullptr;
+}
+
+void
+ComputeBufferReader::Detach()
+{
+    // Mark the stream as abandoned so blocked readers and writers wake up.
+    if (!m_Detail)
+    {
+        return;
+    }
+    std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+    m_Detail->Detached = true;
+    m_Detail->ReadCV.notify_all();
+}
+
+bool
+ComputeBufferReader::IsValid() const
+{
+    return static_cast<bool>(m_Detail);
+}
+
+/** True once no more data will ever be readable: the reader was closed or
+ *  detached, or the writer finished and the read cursor has consumed the
+ *  final chunk. */
+bool
+ComputeBufferReader::IsComplete() const
+{
+    if (!m_Detail)
+    {
+        return true;
+    }
+    std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+    if (m_Detail->Detached)
+    {
+        return true;
+    }
+    // Complete when the writer is done, both cursors sit in the same chunk,
+    // and everything written into that chunk has been read.
+    return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx &&
+           m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx];
+}
+
+/** Consume Size bytes from the current chunk.
+ *  If that exhausts a chunk the writer has already finished, the read cursor
+ *  wraps to the next chunk and the writer is woken (space was freed). */
+void
+ComputeBufferReader::AdvanceReadPosition(size_t Size)
+{
+    if (!m_Detail)
+    {
+        return;
+    }
+
+    std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+
+    m_Detail->ReadOffset += Size;
+
+    // Check if we need to move to next chunk
+    const size_t ReadChunk = m_Detail->ReadChunkIdx;
+    if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk])
+    {
+        const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks;
+        m_Detail->ReadChunkIdx = NextChunk;
+        m_Detail->ReadOffset = 0;
+        // Writer may be blocked in WaitToWrite waiting for this chunk to free up.
+        m_Detail->WriteCV.notify_all();
+    }
+
+    m_Detail->ReadCV.notify_all();
+}
+
+/** Bytes immediately readable in the current chunk (without blocking). */
+size_t
+ComputeBufferReader::GetMaxReadSize() const
+{
+    if (!m_Detail)
+    {
+        return 0;
+    }
+    std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+    const size_t ReadChunk = m_Detail->ReadChunkIdx;
+    return m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
+}
+
+/** Block until at least MinSize bytes are readable in the current chunk.
+ *  @param MinSize     minimum contiguous bytes required
+ *  @param TimeoutMs   negative = wait forever
+ *  @param OutTimedOut set to true on timeout, false otherwise (if non-null)
+ *  @return direct pointer into the buffer, or nullptr on timeout,
+ *          detach, or end-of-stream. */
+const uint8_t*
+ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut)
+{
+    // Initialize the out-parameter on every path; previously it was only
+    // written on the timeout path, leaving callers with an indeterminate bool.
+    if (OutTimedOut)
+    {
+        *OutTimedOut = false;
+    }
+
+    if (!m_Detail)
+    {
+        return nullptr;
+    }
+
+    std::unique_lock<std::mutex> Lock(m_Detail->Mutex);
+
+    // NOTE: the predicate intentionally mutates state under the lock — when the
+    // current chunk is finished and fully consumed it advances the read cursor
+    // to the next chunk and wakes the writer, then re-evaluates.
+    auto Predicate = [&]() -> bool {
+        if (m_Detail->Detached)
+        {
+            return true;
+        }
+
+        const size_t ReadChunk = m_Detail->ReadChunkIdx;
+        const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
+
+        if (Available >= MinSize)
+        {
+            return true;
+        }
+
+        // If chunk is finished and we've read everything, try to move to next
+        if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk])
+        {
+            if (m_Detail->WriteComplete)
+            {
+                return true; // End of stream
+            }
+            // Move to next chunk
+            const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks;
+            m_Detail->ReadChunkIdx = NextChunk;
+            m_Detail->ReadOffset = 0;
+            m_Detail->WriteCV.notify_all();
+            return false; // Re-check with new chunk
+        }
+
+        if (m_Detail->WriteComplete)
+        {
+            return true; // End of stream
+        }
+
+        return false;
+    };
+
+    if (TimeoutMs < 0)
+    {
+        m_Detail->ReadCV.wait(Lock, Predicate);
+    }
+    else
+    {
+        if (!m_Detail->ReadCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate))
+        {
+            if (OutTimedOut)
+            {
+                *OutTimedOut = true;
+            }
+            return nullptr;
+        }
+    }
+
+    if (m_Detail->Detached)
+    {
+        return nullptr;
+    }
+
+    const size_t ReadChunk = m_Detail->ReadChunkIdx;
+    const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
+
+    if (Available < MinSize)
+    {
+        return nullptr; // End of stream
+    }
+
+    return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset;
+}
+
+/** Copy up to MaxSize bytes into Buffer.
+ *  Blocks (per TimeoutMs) until at least one byte is available. Returns the
+ *  number of bytes copied, or 0 on timeout / end-of-stream / detach. */
+size_t
+ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut)
+{
+    const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut);
+    if (!Data)
+    {
+        return 0;
+    }
+
+    // GetMaxReadSize re-locks after WaitToRead released the mutex; assuming a
+    // single reader per buffer (only the reader moves the read cursor), the
+    // available count can only grow between the two calls, so the copy below
+    // stays within valid data. TODO(review): confirm single-reader usage.
+    const size_t Available = GetMaxReadSize();
+    const size_t ToCopy = std::min(Available, MaxSize);
+    memcpy(Buffer, Data, ToCopy);
+    AdvanceReadPosition(ToCopy);
+    return ToCopy;
+}
+
+// ComputeBufferWriter
+
+ComputeBufferWriter::ComputeBufferWriter() = default;
+ComputeBufferWriter::~ComputeBufferWriter() = default;
+ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default;
+ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default;
+ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default;
+ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default;
+
+ComputeBufferWriter::ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
+{
+}
+
+void
+ComputeBufferWriter::Close()
+{
+    // Closing the writer implicitly completes the stream so blocked readers
+    // wake up, then drops our reference to the shared state.
+    if (!m_Detail)
+    {
+        return;
+    }
+    {
+        std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+        if (!m_Detail->WriteComplete)
+        {
+            m_Detail->WriteComplete = true;
+            m_Detail->ReadCV.notify_all();
+        }
+    }
+    m_Detail = nullptr;
+}
+
+bool
+ComputeBufferWriter::IsValid() const
+{
+    return static_cast<bool>(m_Detail);
+}
+
+/** Signal end-of-stream without releasing the writer's reference.
+ *  Wakes any blocked readers. */
+void
+ComputeBufferWriter::MarkComplete()
+{
+    if (m_Detail)
+    {
+        std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+        m_Detail->WriteComplete = true;
+        m_Detail->ReadCV.notify_all();
+    }
+}
+
+/** Publish Size bytes previously written via the WaitToWrite pointer.
+ *  Wakes readers waiting for data. Caller must not exceed the space
+ *  reported by GetMaxWriteSize. */
+void
+ComputeBufferWriter::AdvanceWritePosition(size_t Size)
+{
+    if (!m_Detail || Size == 0)
+    {
+        return;
+    }
+
+    std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+    const size_t WriteChunk = m_Detail->WriteChunkIdx;
+    m_Detail->ChunkWrittenLength[WriteChunk] += Size;
+    m_Detail->WriteOffset += Size;
+    m_Detail->ReadCV.notify_all();
+}
+
+/** Free space remaining in the current write chunk (without blocking). */
+size_t
+ComputeBufferWriter::GetMaxWriteSize() const
+{
+    if (!m_Detail)
+    {
+        return 0;
+    }
+    std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+    const size_t WriteChunk = m_Detail->WriteChunkIdx;
+    return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];
+}
+
+/** Capacity of one chunk — the largest contiguous write the buffer supports.
+ *  ChunkLength is set once in CreateNew before the buffer is shared, so this
+ *  read is safe without taking the mutex. */
+size_t
+ComputeBufferWriter::GetChunkMaxLength() const
+{
+    if (!m_Detail)
+    {
+        return 0;
+    }
+    return m_Detail->ChunkLength;
+}
+
+/** Copy up to MaxSize bytes from Buffer into the ring.
+ *  Blocks (per TimeoutMs) until at least one byte of space exists. Returns
+ *  the number of bytes accepted, or 0 on timeout / completed stream. */
+size_t
+ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs)
+{
+    uint8_t* Dest = WaitToWrite(1, TimeoutMs);
+    if (!Dest)
+    {
+        return 0;
+    }
+
+    // GetMaxWriteSize re-locks after WaitToWrite released the mutex; only the
+    // writer grows ChunkWrittenLength, so the reported space stays valid.
+    const size_t Available = GetMaxWriteSize();
+    const size_t ToCopy = std::min(Available, MaxSize);
+    memcpy(Dest, Buffer, ToCopy);
+    AdvanceWritePosition(ToCopy);
+    return ToCopy;
+}
+
+/** Block until MinSize contiguous bytes of write space are available.
+ *  If the current chunk cannot satisfy MinSize it is sealed and the writer
+ *  waits for the reader to vacate the next chunk before wrapping onto it.
+ *  @return direct pointer to write into, or nullptr on timeout/detach/complete.
+ *
+ *  NOTE(review): if MinSize exceeds ChunkLength, after wrapping this still
+ *  returns a pointer into a chunk that can never hold MinSize bytes — callers
+ *  must not request more than GetChunkMaxLength(); verify at call sites. */
+uint8_t*
+ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs)
+{
+    if (!m_Detail)
+    {
+        return nullptr;
+    }
+
+    std::unique_lock<std::mutex> Lock(m_Detail->Mutex);
+
+    if (m_Detail->WriteComplete)
+    {
+        return nullptr;
+    }
+
+    const size_t WriteChunk = m_Detail->WriteChunkIdx;
+    const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];
+
+    // If current chunk has enough space, return pointer
+    if (Available >= MinSize)
+    {
+        return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk];
+    }
+
+    // Current chunk is full - mark it as finished and move to next.
+    // The writer cannot advance until the reader has fully consumed the next chunk,
+    // preventing the writer from overwriting data the reader hasn't processed yet.
+    m_Detail->ChunkFinished[WriteChunk] = true;
+    m_Detail->ReadCV.notify_all();
+
+    const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks;
+
+    // Wait until reader has consumed the next chunk
+    auto Predicate = [&]() -> bool {
+        // Check if read has moved past this chunk
+        return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached;
+    };
+
+    if (TimeoutMs < 0)
+    {
+        m_Detail->WriteCV.wait(Lock, Predicate);
+    }
+    else
+    {
+        if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate))
+        {
+            return nullptr;
+        }
+    }
+
+    if (m_Detail->Detached)
+    {
+        return nullptr;
+    }
+
+    // Reset next chunk
+    m_Detail->ChunkWrittenLength[NextChunk] = 0;
+    m_Detail->ChunkFinished[NextChunk] = false;
+    m_Detail->WriteChunkIdx = NextChunk;
+    m_Detail->WriteOffset = 0;
+
+    return m_Detail->ChunkPtr(NextChunk);
+}
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputebuffer.h b/src/zenhorde/hordecomputebuffer.h
new file mode 100644
index 000000000..64ef91b7a
--- /dev/null
+++ b/src/zenhorde/hordecomputebuffer.h
@@ -0,0 +1,136 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenbase/refcount.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <mutex>
+#include <vector>
+
+namespace zen::horde {
+
+class ComputeBufferReader;
+class ComputeBufferWriter;
+
+/** Simplified in-process ring buffer for the Horde compute protocol.
+ *
+ * Unlike the UE FComputeBuffer which supports shared-memory and memory-mapped files,
+ * this implementation uses plain heap-allocated memory since we only need in-process
+ * communication between channel and transport threads. The buffer is divided into
+ * fixed-size chunks; readers and writers block when no space is available.
+ */
+class ComputeBuffer
+{
+public:
+    struct Params
+    {
+        size_t NumChunks = 2;            ///< Number of ring chunks; the writer wraps to the next when one fills
+        size_t ChunkLength = 512 * 1024; ///< Capacity of each chunk in bytes (largest contiguous read/write)
+    };
+
+    ComputeBuffer();
+    ~ComputeBuffer();
+
+    ComputeBuffer(const ComputeBuffer&) = delete;
+    ComputeBuffer& operator=(const ComputeBuffer&) = delete;
+
+    bool CreateNew(const Params& InParams);
+    void Close();
+
+    bool IsValid() const;
+
+    // Endpoints hold their own reference to the shared state and therefore
+    // remain usable after Close() is called on this object.
+    ComputeBufferReader CreateReader();
+    ComputeBufferWriter CreateWriter();
+
+private:
+    struct Detail;
+    Ref<Detail> m_Detail;
+
+    friend class ComputeBufferReader;
+    friend class ComputeBufferWriter;
+};
+
+/** Read endpoint for a ComputeBuffer.
+ *
+ * Provides blocking reads from the ring buffer. WaitToRead() returns a pointer
+ * directly into the buffer memory (zero-copy); the caller must call
+ * AdvanceReadPosition() after consuming the data.
+ */
+class ComputeBufferReader
+{
+public:
+    ComputeBufferReader();
+    ComputeBufferReader(const ComputeBufferReader&);
+    ComputeBufferReader(ComputeBufferReader&&) noexcept;
+    ~ComputeBufferReader();
+
+    ComputeBufferReader& operator=(const ComputeBufferReader&);
+    ComputeBufferReader& operator=(ComputeBufferReader&&) noexcept;
+
+    void Close();
+    void Detach();  ///< Abandon the stream; wakes any blocked reader/writer
+    bool IsValid() const;
+    bool IsComplete() const;
+
+    void AdvanceReadPosition(size_t Size);
+    size_t GetMaxReadSize() const;
+
+    /** Copy up to MaxSize bytes from the buffer into Buffer. Blocks until data is available. */
+    size_t Read(void* Buffer, size_t MaxSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr);
+
+    /** Wait until at least MinSize bytes are available and return a direct pointer.
+     * Returns nullptr on timeout or if the writer has completed. */
+    const uint8_t* WaitToRead(size_t MinSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr);
+
+private:
+    friend class ComputeBuffer;
+    explicit ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail);
+
+    Ref<ComputeBuffer::Detail> m_Detail;
+};
+
+/** Write endpoint for a ComputeBuffer.
+ *
+ * Provides blocking writes into the ring buffer. WaitToWrite() returns a pointer
+ * directly into the buffer memory (zero-copy); the caller must call
+ * AdvanceWritePosition() after filling the data. Call MarkComplete() to signal
+ * that no more data will be written.
+ */
+class ComputeBufferWriter
+{
+public:
+    ComputeBufferWriter();
+    ComputeBufferWriter(const ComputeBufferWriter&);
+    ComputeBufferWriter(ComputeBufferWriter&&) noexcept;
+    ~ComputeBufferWriter();
+
+    ComputeBufferWriter& operator=(const ComputeBufferWriter&);
+    ComputeBufferWriter& operator=(ComputeBufferWriter&&) noexcept;
+
+    void Close();
+    bool IsValid() const;
+
+    /** Signal that no more data will be written. Unblocks any waiting readers. */
+    void MarkComplete();
+
+    void AdvanceWritePosition(size_t Size);
+    size_t GetMaxWriteSize() const;
+    size_t GetChunkMaxLength() const;  ///< Upper bound for any single WaitToWrite request
+
+    /** Copy up to MaxSize bytes from Buffer into the ring buffer. Blocks until space is available. */
+    size_t Write(const void* Buffer, size_t MaxSize, int TimeoutMs = -1);
+
+    /** Wait until at least MinSize bytes of write space are available and return a direct pointer.
+     * Returns nullptr on timeout. */
+    uint8_t* WaitToWrite(size_t MinSize, int TimeoutMs = -1);
+
+private:
+    friend class ComputeBuffer;
+    explicit ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail);
+
+    Ref<ComputeBuffer::Detail> m_Detail;
+};
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputechannel.cpp b/src/zenhorde/hordecomputechannel.cpp
new file mode 100644
index 000000000..ee2a6f327
--- /dev/null
+++ b/src/zenhorde/hordecomputechannel.cpp
@@ -0,0 +1,37 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordecomputechannel.h"
+
+namespace zen::horde {
+
+// Takes ownership of one read endpoint and one write endpoint; together they
+// form one bidirectional logical channel.
+ComputeChannel::ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter)
+: Reader(std::move(InReader))
+, Writer(std::move(InWriter))
+{
+}
+
+bool
+ComputeChannel::IsValid() const
+{
+    return Reader.IsValid() && Writer.IsValid();
+}
+
+// Blocking send; returns the number of bytes accepted (may be a short write).
+size_t
+ComputeChannel::Send(const void* Data, size_t Size, int TimeoutMs)
+{
+    return Writer.Write(Data, Size, TimeoutMs);
+}
+
+// Blocking receive; returns the number of bytes copied (0 on timeout/end).
+size_t
+ComputeChannel::Recv(void* Data, size_t Size, int TimeoutMs)
+{
+    return Reader.Read(Data, Size, TimeoutMs);
+}
+
+// End-of-stream for the send direction; the receive direction is unaffected.
+void
+ComputeChannel::MarkComplete()
+{
+    Writer.MarkComplete();
+}
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputechannel.h b/src/zenhorde/hordecomputechannel.h
new file mode 100644
index 000000000..c1dff20e4
--- /dev/null
+++ b/src/zenhorde/hordecomputechannel.h
@@ -0,0 +1,32 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "hordecomputebuffer.h"
+
+namespace zen::horde {
+
+/** Bidirectional communication channel using a pair of compute buffers.
+ *
+ * Pairs a ComputeBufferReader (for receiving data) with a ComputeBufferWriter
+ * (for sending data). Used by ComputeSocket to represent one logical channel
+ * within a multiplexed connection.
+ */
+class ComputeChannel : public TRefCounted<ComputeChannel>
+{
+public:
+    // Endpoints are public so callers can use the zero-copy
+    // WaitToRead/WaitToWrite APIs directly when needed.
+    ComputeBufferReader Reader;
+    ComputeBufferWriter Writer;
+
+    ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter);
+
+    bool IsValid() const;
+
+    /** Blocking send/recv; both return the byte count actually transferred. */
+    size_t Send(const void* Data, size_t Size, int TimeoutMs = -1);
+    size_t Recv(void* Data, size_t Size, int TimeoutMs = -1);
+
+    /** Signal that no more data will be sent on this channel. */
+    void MarkComplete();
+};
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp
new file mode 100644
index 000000000..6ef67760c
--- /dev/null
+++ b/src/zenhorde/hordecomputesocket.cpp
@@ -0,0 +1,204 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordecomputesocket.h"
+
+#include <zencore/logging.h>
+
+namespace zen::horde {
+
+// Takes exclusive ownership of the transport; no threads run until
+// StartCommunication() is called.
+ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport)
+: m_Log(zen::logging::Get("horde.socket"))
+, m_Transport(std::move(Transport))
+{
+}
+
+ComputeSocket::~ComputeSocket()
+{
+    // Shutdown order matters:
+    //  1. Stop AND join the ping thread before anything else touches the
+    //     transport state — previously it was joined only after Close(),
+    //     so a final ping could race against transport shutdown.
+    //  2. Detach send-side readers to unblock per-channel send threads.
+    //  3. Join the send threads (they still use the open transport).
+    //  4. Close the transport to unblock the recv thread (blocked in RecvMessage).
+    //  5. Join the recv thread.
+    {
+        std::lock_guard<std::mutex> Lock(m_PingMutex);
+        m_PingShouldStop = true;
+        m_PingCV.notify_all();
+    }
+    if (m_PingThread.joinable())
+    {
+        m_PingThread.join();
+    }
+
+    for (auto& Reader : m_Readers)
+    {
+        Reader.Detach();
+    }
+
+    for (auto& [Id, Thread] : m_SendThreads)
+    {
+        if (Thread.joinable())
+        {
+            Thread.join();
+        }
+    }
+
+    m_Transport->Close();
+
+    if (m_RecvThread.joinable())
+    {
+        m_RecvThread.join();
+    }
+}
+
+/** Create a channel with the given ID; returns an empty Ref on failure.
+ *  Must be called before StartCommunication(). */
+Ref<ComputeChannel>
+ComputeSocket::CreateChannel(int ChannelId)
+{
+    // Reject duplicate IDs up front. Without this check the m_SendThreads
+    // emplace below would fail and destroy a still-joinable std::thread,
+    // which calls std::terminate.
+    if (m_SendThreads.find(ChannelId) != m_SendThreads.end())
+    {
+        ZEN_WARN("channel {} already exists", ChannelId);
+        return {};
+    }
+
+    ComputeBuffer::Params Params;
+
+    ComputeBuffer RecvBuffer;
+    if (!RecvBuffer.CreateNew(Params))
+    {
+        return {};
+    }
+
+    ComputeBuffer SendBuffer;
+    if (!SendBuffer.CreateNew(Params))
+    {
+        return {};
+    }
+
+    Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter()));
+
+    // Attach recv buffer writer (transport recv thread writes into this)
+    {
+        std::lock_guard<std::mutex> Lock(m_WritersMutex);
+        m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter());
+    }
+
+    // Attach send buffer reader (send thread reads from this)
+    {
+        ComputeBufferReader Reader = SendBuffer.CreateReader();
+        m_Readers.push_back(Reader);
+        m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader)));
+    }
+
+    return Channel;
+}
+
+// Spawn the demultiplexing recv thread and the keep-alive ping thread.
+// Call exactly once, after all channels have been created.
+void
+ComputeSocket::StartCommunication()
+{
+    m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this);
+    m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this);
+}
+
+// Keep-alive loop: every 2 seconds send a control frame with Size=ControlPing
+// on channel 0 until the destructor signals m_PingShouldStop.
+void
+ComputeSocket::PingThreadProc()
+{
+    while (true)
+    {
+        {
+            // wait_for returns true only when the stop predicate holds.
+            std::unique_lock<std::mutex> Lock(m_PingMutex);
+            if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; }))
+            {
+                break;
+            }
+        }
+
+        // m_SendMutex serializes this header with the per-channel send threads.
+        std::lock_guard<std::mutex> Lock(m_SendMutex);
+        FrameHeader Header;
+        Header.Channel = 0;
+        Header.Size = ControlPing;
+        m_Transport->SendMessage(&Header, sizeof(Header));
+    }
+}
+
+/** Demultiplexing loop: reads frame headers from the transport and routes
+ *  payloads into the per-channel recv buffers. Exits when the transport
+ *  closes or a protocol error is detected. */
+void
+ComputeSocket::RecvThreadProc()
+{
+    // Writers are cached locally to avoid taking m_WritersMutex on every frame.
+    // The shared m_Writers map is only accessed when a channel is seen for the first time.
+    std::unordered_map<int, ComputeBufferWriter> CachedWriters;
+
+    FrameHeader Header;
+    while (m_Transport->RecvMessage(&Header, sizeof(Header)))
+    {
+        if (Header.Size >= 0)
+        {
+            // Data frame. The size comes off the wire and is untrusted.
+            const size_t FrameSize = static_cast<size_t>(Header.Size);
+
+            auto It = CachedWriters.find(Header.Channel);
+            if (It == CachedWriters.end())
+            {
+                std::lock_guard<std::mutex> Lock(m_WritersMutex);
+                auto WIt = m_Writers.find(Header.Channel);
+                if (WIt == m_Writers.end())
+                {
+                    ZEN_WARN("recv frame for unknown channel {}", Header.Channel);
+                    // Discard the payload in bounded pieces rather than one
+                    // attacker-controlled FrameSize-byte allocation.
+                    uint8_t Discard[4096];
+                    size_t Remaining = FrameSize;
+                    while (Remaining > 0)
+                    {
+                        const size_t Piece = (Remaining < sizeof(Discard)) ? Remaining : sizeof(Discard);
+                        if (!m_Transport->RecvMessage(Discard, Piece))
+                        {
+                            return;
+                        }
+                        Remaining -= Piece;
+                    }
+                    continue;
+                }
+                It = CachedWriters.emplace(Header.Channel, WIt->second).first;
+            }
+
+            ComputeBufferWriter& Writer = It->second;
+
+            // A frame larger than one ring-buffer chunk can never fit:
+            // WaitToWrite would hand back at most a chunk-sized region and the
+            // copy below would overrun it (heap overflow). Treat as protocol error.
+            if (FrameSize > Writer.GetChunkMaxLength())
+            {
+                ZEN_WARN("oversized frame (channel={}, size={})", Header.Channel, Header.Size);
+                return;
+            }
+
+            uint8_t* Dest = Writer.WaitToWrite(FrameSize);
+            if (!Dest || !m_Transport->RecvMessage(Dest, FrameSize))
+            {
+                ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size);
+                return;
+            }
+            Writer.AdvanceWritePosition(FrameSize);
+        }
+        else if (Header.Size == ControlDetach)
+        {
+            // Detach the recv buffer for this channel
+            CachedWriters.erase(Header.Channel);
+
+            std::lock_guard<std::mutex> Lock(m_WritersMutex);
+            auto It = m_Writers.find(Header.Channel);
+            if (It != m_Writers.end())
+            {
+                It->second.MarkComplete();
+                m_Writers.erase(It);
+            }
+        }
+        else if (Header.Size == ControlPing)
+        {
+            // Ping response - ignore
+        }
+        else
+        {
+            ZEN_WARN("invalid frame header size: {}", Header.Size);
+            return;
+        }
+    }
+}
+
+/** Per-channel send pump: drains the channel's send buffer onto the shared
+ *  transport, framing each batch with a FrameHeader. When the channel's send
+ *  stream ends, notifies the peer with a ControlDetach frame. */
+void
+ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader)
+{
+    // Each channel has its own send thread. All send threads share m_SendMutex
+    // to serialize writes to the transport, since TCP requires atomic frame writes.
+    FrameHeader Header;
+    Header.Channel = Channel;
+
+    const uint8_t* Data;
+    // WaitToRead returns nullptr on end-of-stream or when the reader is
+    // detached (socket destruction), terminating the loop.
+    while ((Data = Reader.WaitToRead(1)) != nullptr)
+    {
+        std::lock_guard<std::mutex> Lock(m_SendMutex);
+
+        // Header and payload are sent under one lock hold so frames from
+        // different channels cannot interleave on the wire.
+        Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize());
+        m_Transport->SendMessage(&Header, sizeof(Header));
+        m_Transport->SendMessage(Data, Header.Size);
+        Reader.AdvanceReadPosition(Header.Size);
+    }
+
+    if (Reader.IsComplete())
+    {
+        std::lock_guard<std::mutex> Lock(m_SendMutex);
+        Header.Size = ControlDetach;
+        m_Transport->SendMessage(&Header, sizeof(Header));
+    }
+}
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputesocket.h b/src/zenhorde/hordecomputesocket.h
new file mode 100644
index 000000000..0c3cb4195
--- /dev/null
+++ b/src/zenhorde/hordecomputesocket.h
@@ -0,0 +1,79 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "hordecomputebuffer.h"
+#include "hordecomputechannel.h"
+#include "hordetransport.h"
+
+#include <zencore/logbase.h>
+
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+#include <thread>
+#include <unordered_map>
+#include <vector>
+
+namespace zen::horde {
+
+/** Multiplexed socket that routes data between multiple channels over a single transport.
+ *
+ * Each channel is identified by an integer ID and backed by a pair of ComputeBuffers.
+ * A recv thread demultiplexes incoming frames to channel-specific buffers, while
+ * per-channel send threads multiplex outgoing data onto the shared transport.
+ *
+ * Wire format per frame: [channelId (4B)][size (4B)][data]
+ * Control messages use negative sizes: -2 = detach (channel closed), -3 = ping.
+ */
+class ComputeSocket
+{
+public:
+    explicit ComputeSocket(std::unique_ptr<ComputeTransport> Transport);
+    ~ComputeSocket();
+
+    ComputeSocket(const ComputeSocket&) = delete;
+    ComputeSocket& operator=(const ComputeSocket&) = delete;
+
+    /** Create a channel with the given ID.
+     * Allocates anonymous in-process buffers and spawns a send thread for the channel. */
+    Ref<ComputeChannel> CreateChannel(int ChannelId);
+
+    /** Start the recv pump and ping threads. Must be called after all channels are created. */
+    void StartCommunication();
+
+private:
+    // On-the-wire frame header; Size < 0 carries a control code instead of a length.
+    struct FrameHeader
+    {
+        int32_t Channel = 0;
+        int32_t Size = 0;
+    };
+
+    static constexpr int32_t ControlDetach = -2;
+    static constexpr int32_t ControlPing = -3;
+
+    LoggerRef Log() { return m_Log; }
+
+    void RecvThreadProc();
+    void SendThreadProc(int Channel, ComputeBufferReader Reader);
+    void PingThreadProc();
+
+    LoggerRef m_Log;
+    std::unique_ptr<ComputeTransport> m_Transport;
+    std::mutex m_SendMutex; ///< Serializes writes to the transport
+
+    std::mutex m_WritersMutex;
+    std::unordered_map<int, ComputeBufferWriter> m_Writers; ///< Recv-side: writers keyed by channel ID
+
+    // NOTE(review): m_Readers/m_SendThreads are not mutex-protected; they
+    // appear intended to be mutated only before StartCommunication() and in
+    // the destructor — confirm callers respect that.
+    std::vector<ComputeBufferReader> m_Readers; ///< Send-side: readers for join on destruction
+    std::unordered_map<int, std::thread> m_SendThreads; ///< One send thread per channel
+
+    std::thread m_RecvThread;
+    std::thread m_PingThread;
+
+    bool m_PingShouldStop = false;
+    std::mutex m_PingMutex;
+    std::condition_variable m_PingCV;
+};
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp
new file mode 100644
index 000000000..2dca228d9
--- /dev/null
+++ b/src/zenhorde/hordeconfig.cpp
@@ -0,0 +1,89 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhorde/hordeconfig.h>
+
+namespace zen::horde {
+
+bool
+HordeConfig::Validate() const
+{
+    // A server URL is mandatory, and relay connections are only supported
+    // with AES encryption.
+    const bool HasServer = !ServerUrl.empty();
+    const bool RelayNeedsAes = (Mode != ConnectionMode::Relay) || (EncryptionMode == Encryption::AES);
+    return HasServer && RelayNeedsAes;
+}
+
+/** Stringify a connection mode; unknown values fall back to "direct". */
+const char*
+ToString(ConnectionMode Mode)
+{
+    if (Mode == ConnectionMode::Tunnel)
+    {
+        return "tunnel";
+    }
+    if (Mode == ConnectionMode::Relay)
+    {
+        return "relay";
+    }
+    return "direct";
+}
+
+/** Stringify an encryption mode; unknown values fall back to "none". */
+const char*
+ToString(Encryption Enc)
+{
+    return Enc == Encryption::AES ? "aes" : "none";
+}
+
+/** Parse a connection mode name; returns false (OutMode untouched) on no match. */
+bool
+FromString(ConnectionMode& OutMode, std::string_view Str)
+{
+    if (Str == "direct")
+    {
+        OutMode = ConnectionMode::Direct;
+    }
+    else if (Str == "tunnel")
+    {
+        OutMode = ConnectionMode::Tunnel;
+    }
+    else if (Str == "relay")
+    {
+        OutMode = ConnectionMode::Relay;
+    }
+    else
+    {
+        return false;
+    }
+    return true;
+}
+
+/** Parse an encryption mode name; returns false (OutEnc untouched) on no match. */
+bool
+FromString(Encryption& OutEnc, std::string_view Str)
+{
+    if (Str == "none")
+    {
+        OutEnc = Encryption::None;
+    }
+    else if (Str == "aes")
+    {
+        OutEnc = Encryption::AES;
+    }
+    else
+    {
+        return false;
+    }
+    return true;
+}
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp
new file mode 100644
index 000000000..f88c95da2
--- /dev/null
+++ b/src/zenhorde/hordeprovisioner.cpp
@@ -0,0 +1,367 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhorde/hordeclient.h>
+#include <zenhorde/hordeprovisioner.h>
+
+#include "hordeagent.h"
+#include "hordebundle.h"
+
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/thread.h>
+#include <zencore/trace.h>
+
+#include <chrono>
+#include <thread>
+
+namespace zen::horde {
+
+/** Tracks one agent-provisioning thread.
+ *  ShouldExit is set by the destructor (to cancel) and by the thread itself
+ *  on exit (so SetTargetCoreCount can reap finished entries). */
+struct HordeProvisioner::AgentWrapper
+{
+    std::thread Thread;
+    std::atomic<bool> ShouldExit{false};
+};
+
+// Stores configuration only; no agents are requested until
+// SetTargetCoreCount() asks for cores.
+HordeProvisioner::HordeProvisioner(const HordeConfig& Config,
+                                   const std::filesystem::path& BinariesPath,
+                                   const std::filesystem::path& WorkingDir,
+                                   std::string_view OrchestratorEndpoint)
+: m_Config(Config)
+, m_BinariesPath(BinariesPath)
+, m_WorkingDir(WorkingDir)
+, m_OrchestratorEndpoint(OrchestratorEndpoint)
+, m_Log(zen::logging::Get("horde.provisioner"))
+{
+}
+
+// Signal every agent thread to stop, then join them all. Done in two passes
+// so no thread waits on a cancel flag that has not been raised yet.
+// Agent threads never take m_AgentsLock themselves, so holding it across the
+// joins cannot deadlock.
+HordeProvisioner::~HordeProvisioner()
+{
+    std::lock_guard<std::mutex> Lock(m_AgentsLock);
+    for (auto& Agent : m_Agents)
+    {
+        Agent->ShouldExit.store(true);
+    }
+    for (auto& Agent : m_Agents)
+    {
+        if (Agent->Thread.joinable())
+        {
+            Agent->Thread.join();
+        }
+    }
+}
+
+/** Set the desired number of remote cores (clamped to the configured max).
+ *  Spawns agent-request threads until the speculative core estimate covers
+ *  the target, then reaps any agent threads that have already finished. */
+void
+HordeProvisioner::SetTargetCoreCount(uint32_t Count)
+{
+    ZEN_TRACE_CPU("HordeProvisioner::SetTargetCoreCount");
+
+    m_TargetCoreCount.store(std::min(Count, static_cast<uint32_t>(m_Config.MaxCores)));
+
+    // RequestAgent bumps m_EstimatedCoreCount, so this loop terminates.
+    while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load())
+    {
+        // m_AskForAgents is cleared permanently once provisioning proves
+        // impossible (bundle/client init failure).
+        if (!m_AskForAgents.load())
+        {
+            return;
+        }
+        RequestAgent();
+    }
+
+    // Clean up finished agent threads
+    std::lock_guard<std::mutex> Lock(m_AgentsLock);
+    for (auto It = m_Agents.begin(); It != m_Agents.end();)
+    {
+        if ((*It)->ShouldExit.load())
+        {
+            if ((*It)->Thread.joinable())
+            {
+                (*It)->Thread.join();
+            }
+            It = m_Agents.erase(It);
+        }
+        else
+        {
+            ++It;
+        }
+    }
+}
+
+/** Snapshot of the provisioning counters. Each atomic is loaded
+ *  independently, so the fields are individually (not mutually) consistent. */
+ProvisioningStats
+HordeProvisioner::GetStats() const
+{
+    ProvisioningStats Snapshot;
+    Snapshot.TargetCoreCount = m_TargetCoreCount.load();
+    Snapshot.EstimatedCoreCount = m_EstimatedCoreCount.load();
+    Snapshot.ActiveCoreCount = m_ActiveCoreCount.load();
+    Snapshot.AgentsActive = m_AgentsActive.load();
+    Snapshot.AgentsRequesting = m_AgentsRequesting.load();
+    return Snapshot;
+}
+
+/** Number of agent threads currently tracked (running or awaiting reap). */
+uint32_t
+HordeProvisioner::GetAgentCount() const
+{
+    std::lock_guard<std::mutex> Lock(m_AgentsLock);
+    const size_t Count = m_Agents.size();
+    return static_cast<uint32_t>(Count);
+}
+
+/** Start one agent-provisioning thread.
+ *  Speculatively accounts EstimatedCoresPerAgent cores so SetTargetCoreCount
+ *  does not over-provision; ThreadAgent reconciles the estimate later. */
+void
+HordeProvisioner::RequestAgent()
+{
+    m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent);
+
+    std::lock_guard<std::mutex> Lock(m_AgentsLock);
+
+    auto Wrapper = std::make_unique<AgentWrapper>();
+    // Capture a raw pointer rather than a reference named `Ref`, which
+    // shadowed the Ref<T> smart-pointer template. The pointee is stable:
+    // the wrapper is heap-owned via unique_ptr in m_Agents and only freed
+    // after its thread is joined.
+    AgentWrapper* WrapperPtr = Wrapper.get();
+    Wrapper->Thread = std::thread([this, WrapperPtr] { ThreadAgent(*WrapperPtr); });
+
+    m_Agents.push_back(std::move(Wrapper));
+}
+
+void
+HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper)
+{
+ ZEN_TRACE_CPU("HordeProvisioner::ThreadAgent");
+
+ static std::atomic<uint32_t> ThreadIndex{0};
+ const uint32_t CurrentIndex = ThreadIndex.fetch_add(1);
+
+ zen::SetCurrentThreadName(fmt::format("horde_agent_{}", CurrentIndex));
+
+ std::unique_ptr<HordeAgent> Agent;
+ uint32_t MachineCoreCount = 0;
+
+ auto _ = MakeGuard([&] {
+ if (Agent)
+ {
+ Agent->CloseConnection();
+ }
+ Wrapper.ShouldExit.store(true);
+ });
+
+ {
+ // EstimatedCoreCount is incremented speculatively when the agent is requested
+ // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision.
+ auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); });
+
+ {
+ ZEN_TRACE_CPU("HordeProvisioner::CreateBundles");
+
+ std::lock_guard<std::mutex> BundleLock(m_BundleLock);
+
+ if (!m_BundlesCreated)
+ {
+ const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles";
+
+ std::vector<BundleFile> Files;
+
+#if ZEN_PLATFORM_WINDOWS
+ Files.emplace_back(m_BinariesPath / "zenserver.exe", false);
+#elif ZEN_PLATFORM_LINUX
+ Files.emplace_back(m_BinariesPath / "zenserver", false);
+ Files.emplace_back(m_BinariesPath / "zenserver.debug", true);
+#elif ZEN_PLATFORM_MAC
+ Files.emplace_back(m_BinariesPath / "zenserver", false);
+#endif
+
+ BundleResult Result;
+ if (!BundleCreator::CreateBundle(Files, OutputDir, Result))
+ {
+ ZEN_WARN("failed to create bundle, cannot provision any agents!");
+ m_AskForAgents.store(false);
+ return;
+ }
+
+ m_Bundles.emplace_back(Result.Locator, Result.BundleDir);
+ m_BundlesCreated = true;
+ }
+
+ if (!m_HordeClient)
+ {
+ m_HordeClient = std::make_unique<HordeClient>(m_Config);
+ if (!m_HordeClient->Initialize())
+ {
+ ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!");
+ m_AskForAgents.store(false);
+ return;
+ }
+ }
+ }
+
+ if (!m_AskForAgents.load())
+ {
+ return;
+ }
+
+ m_AgentsRequesting.fetch_add(1);
+ auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); });
+
+ // Simple backoff: if the last machine request failed, wait up to 5 seconds
+ // before trying again.
+ //
+ // Note however that it's possible that multiple threads enter this code at
+ // the same time if multiple agents are requested at once, and they will all
+ // see the same last failure time and back off accordingly. We might want to
+ // use a semaphore or similar to limit the number of concurrent requests.
+
+ if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0)
+ {
+ auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
+ const uint64_t ElapsedNs = Now - LastFail;
+ const uint64_t ElapsedMs = ElapsedNs / 1'000'000;
+ if (ElapsedMs < 5000)
+ {
+ const uint64_t WaitMs = 5000 - ElapsedMs;
+ for (uint64_t Waited = 0; Waited < WaitMs && !Wrapper.ShouldExit.load(); Waited += 100)
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+
+ if (Wrapper.ShouldExit.load())
+ {
+ return;
+ }
+ }
+ }
+
+ if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load())
+ {
+ return;
+ }
+
+ std::string RequestBody = m_HordeClient->BuildRequestBody();
+
+ // Resolve cluster if needed
+ std::string ClusterId = m_Config.Cluster;
+ if (ClusterId == HordeConfig::ClusterAuto)
+ {
+ ClusterInfo Cluster;
+ if (!m_HordeClient->ResolveCluster(RequestBody, Cluster))
+ {
+ ZEN_WARN("failed to resolve cluster");
+ m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
+ return;
+ }
+ ClusterId = Cluster.ClusterId;
+ }
+
+ MachineInfo Machine;
+ if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid())
+ {
+ m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
+ return;
+ }
+
+ m_LastRequestFailTime.store(0);
+
+ if (Wrapper.ShouldExit.load())
+ {
+ return;
+ }
+
+ // Connect to agent and perform handshake
+ Agent = std::make_unique<HordeAgent>(Machine);
+ if (!Agent->IsValid())
+ {
+ ZEN_WARN("agent creation failed for {}:{}", Machine.GetConnectionAddress(), Machine.GetConnectionPort());
+ return;
+ }
+
+ if (!Agent->BeginCommunication())
+ {
+ ZEN_WARN("BeginCommunication failed");
+ return;
+ }
+
+ for (auto& [Locator, BundleDir] : m_Bundles)
+ {
+ if (Wrapper.ShouldExit.load())
+ {
+ return;
+ }
+
+ if (!Agent->UploadBinaries(BundleDir, Locator))
+ {
+ ZEN_WARN("UploadBinaries failed");
+ return;
+ }
+ }
+
+ if (Wrapper.ShouldExit.load())
+ {
+ return;
+ }
+
+ // Build command line for remote zenserver
+ std::vector<std::string> ArgStrings;
+ ArgStrings.push_back("compute");
+ ArgStrings.push_back("--http=asio");
+
+ // TEMP HACK - these should be made fully dynamic
+ // these are currently here to allow spawning the compute agent locally
+ // for debugging purposes (i.e with a local Horde Server+Agent setup)
+ ArgStrings.push_back(fmt::format("--port={}", m_Config.ZenServicePort));
+ ArgStrings.push_back("--data-dir=c:\\temp\\123");
+
+ if (!m_OrchestratorEndpoint.empty())
+ {
+ ExtendableStringBuilder<256> CoordArg;
+ CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint;
+ ArgStrings.emplace_back(CoordArg.ToView());
+ }
+
+ {
+ ExtendableStringBuilder<128> IdArg;
+ IdArg << "--instance-id=horde-" << Machine.LeaseId;
+ ArgStrings.emplace_back(IdArg.ToView());
+ }
+
+ std::vector<const char*> Args;
+ Args.reserve(ArgStrings.size());
+ for (const std::string& Arg : ArgStrings)
+ {
+ Args.push_back(Arg.c_str());
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ const bool UseWine = !Machine.IsWindows;
+ const char* AppName = "zenserver.exe";
+#else
+ const bool UseWine = false;
+ const char* AppName = "zenserver";
+#endif
+
+ Agent->Execute(AppName, Args.data(), Args.size(), nullptr, nullptr, 0, UseWine);
+
+ ZEN_INFO("remote execution started on [{}:{}] lease={}",
+ Machine.GetConnectionAddress(),
+ Machine.GetConnectionPort(),
+ Machine.LeaseId);
+
+ MachineCoreCount = Machine.LogicalCores;
+ m_EstimatedCoreCount.fetch_add(MachineCoreCount);
+ m_ActiveCoreCount.fetch_add(MachineCoreCount);
+ m_AgentsActive.fetch_add(1);
+ }
+
+ // Agent poll loop
+
+ auto ActiveGuard = MakeGuard([&]() {
+ m_EstimatedCoreCount.fetch_sub(MachineCoreCount);
+ m_ActiveCoreCount.fetch_sub(MachineCoreCount);
+ m_AgentsActive.fetch_sub(1);
+ });
+
+ while (Agent->IsValid() && !Wrapper.ShouldExit.load())
+ {
+ const bool LogOutput = false;
+ if (!Agent->Poll(LogOutput))
+ {
+ break;
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+}
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordetransport.cpp b/src/zenhorde/hordetransport.cpp
new file mode 100644
index 000000000..69766e73e
--- /dev/null
+++ b/src/zenhorde/hordetransport.cpp
@@ -0,0 +1,169 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordetransport.h"
+
+#include <zencore/logging.h>
+#include <zencore/trace.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# undef SendMessage
+#endif
+
+namespace zen::horde {
+
+// ComputeTransport base
+
+// Push the entire buffer through Send(), retrying until every byte has gone
+// out on the wire. A zero return from Send() indicates a transport error.
+bool
+ComputeTransport::SendMessage(const void* Data, size_t Size)
+{
+    const uint8_t* Cursor = static_cast<const uint8_t*>(Data);
+    const uint8_t* End    = Cursor + Size;
+
+    while (Cursor != End)
+    {
+        const size_t Written = Send(Cursor, static_cast<size_t>(End - Cursor));
+        if (Written == 0)
+        {
+            return false;
+        }
+        Cursor += Written;
+    }
+
+    return true;
+}
+
+// Fill the entire buffer from Recv(), retrying until every byte has arrived.
+// A zero return from Recv() indicates an error or end-of-stream.
+bool
+ComputeTransport::RecvMessage(void* Data, size_t Size)
+{
+    uint8_t*       Cursor = static_cast<uint8_t*>(Data);
+    const uint8_t* End    = Cursor + Size;
+
+    while (Cursor != End)
+    {
+        const size_t Read = Recv(Cursor, static_cast<size_t>(End - Cursor));
+        if (Read == 0)
+        {
+            return false;
+        }
+        Cursor += Read;
+    }
+
+    return true;
+}
+
+// TcpComputeTransport - ASIO pimpl
+
+// Hides the ASIO io_context/socket pair from the public header (pimpl).
+struct TcpComputeTransport::Impl
+{
+    asio::io_context      IoContext;
+    asio::ip::tcp::socket Socket;
+
+    // The socket must be constructed from (and must not outlive) IoContext.
+    Impl() : Socket(IoContext) {}
+};
+
+// Uses ASIO in synchronous mode only — no async operations or io_context::run().
+// The io_context is only needed because ASIO sockets require one to be constructed.
+//
+// The connection address is usually an IP literal, but relay/tunnel addresses can
+// be DNS names; when literal parsing fails we fall back to a blocking resolve and
+// try each resolved endpoint in turn.
+TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info)
+: m_Impl(std::make_unique<Impl>())
+, m_Log(zen::logging::Get("horde.transport"))
+{
+    ZEN_TRACE_CPU("TcpComputeTransport::Connect");
+
+    asio::error_code Ec;
+
+    const asio::ip::address Address = asio::ip::make_address(Info.GetConnectionAddress(), Ec);
+    if (!Ec)
+    {
+        const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort());
+        m_Impl->Socket.connect(Endpoint, Ec);
+    }
+    else
+    {
+        // Not an IP literal — treat it as a host name
+        asio::ip::tcp::resolver Resolver(m_Impl->IoContext);
+        const auto Endpoints = Resolver.resolve(Info.GetConnectionAddress(), std::to_string(Info.GetConnectionPort()), Ec);
+        if (!Ec)
+        {
+            asio::connect(m_Impl->Socket, Endpoints, Ec);
+        }
+    }
+
+    if (Ec)
+    {
+        ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message());
+        m_HasErrors = true;
+        return;
+    }
+
+    // Disable Nagle's algorithm for lower latency
+    m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec);
+}
+
+// Ensure the socket is shut down and closed on destruction (Close is idempotent).
+TcpComputeTransport::~TcpComputeTransport()
+{
+    Close();
+}
+
+bool
+TcpComputeTransport::IsValid() const
+{
+ return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed;
+}
+
+// Send at most Size bytes; returns the number of bytes accepted by the
+// socket, or 0 on error (in which case the transport is marked failed).
+size_t
+TcpComputeTransport::Send(const void* Data, size_t Size)
+{
+    if (!IsValid())
+    {
+        return 0;
+    }
+
+    asio::error_code Ec;
+    const size_t     BytesOut = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec);
+    if (!Ec)
+    {
+        return BytesOut;
+    }
+
+    m_HasErrors = true;
+    return 0;
+}
+
+size_t
+TcpComputeTransport::Recv(void* Data, size_t Size)
+{
+ if (!IsValid())
+ {
+ return 0;
+ }
+
+ asio::error_code Ec;
+ const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec);
+
+ if (Ec)
+ {
+ return 0;
+ }
+
+ return Received;
+}
+
+// No-op for plain TCP: there is no in-band end-of-stream marker to emit; the
+// peer observes completion when the socket is shut down in Close().
+void
+TcpComputeTransport::MarkComplete()
+{
+}
+
+// Shut down both directions and close the socket. Safe to call repeatedly;
+// also invoked from the destructor.
+void
+TcpComputeTransport::Close()
+{
+    if (m_IsClosed)
+    {
+        return;
+    }
+    m_IsClosed = true;
+
+    if (m_Impl && m_Impl->Socket.is_open())
+    {
+        asio::error_code Ignored;
+        m_Impl->Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ignored);
+        m_Impl->Socket.close(Ignored);
+    }
+}
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordetransport.h b/src/zenhorde/hordetransport.h
new file mode 100644
index 000000000..1b178dc0f
--- /dev/null
+++ b/src/zenhorde/hordetransport.h
@@ -0,0 +1,71 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhorde/hordeclient.h>
+
+#include <zencore/logbase.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+
+#if ZEN_PLATFORM_WINDOWS
+# undef SendMessage
+#endif
+
+namespace zen::horde {
+
+/** Abstract base interface for compute transports.
+ *
+ * Matches the UE FComputeTransport pattern. Concrete implementations handle
+ * the underlying I/O (TCP, AES-wrapped, etc.) while this interface provides
+ * blocking message helpers on top.
+ */
+class ComputeTransport
+{
+public:
+	virtual ~ComputeTransport() = default;
+
+	/** True while the transport can still send/receive. */
+	virtual bool   IsValid() const = 0;
+	/** Send at most Size bytes; returns bytes written, 0 on error. May write fewer than Size. */
+	virtual size_t Send(const void* Data, size_t Size) = 0;
+	/** Receive at most Size bytes; returns bytes read, 0 on error or end-of-stream. */
+	virtual size_t Recv(void* Data, size_t Size) = 0;
+	/** Signal that no further sends will follow (transport-specific; may be a no-op). */
+	virtual void   MarkComplete() = 0;
+	/** Tear down the underlying connection. Implementations are expected to be idempotent. */
+	virtual void   Close() = 0;
+
+	/** Blocking send that loops until all bytes are transferred. Returns false on error. */
+	bool SendMessage(const void* Data, size_t Size);
+
+	/** Blocking receive that loops until all bytes are transferred. Returns false on error. */
+	bool RecvMessage(void* Data, size_t Size);
+};
+
+/** TCP socket transport using ASIO.
+ *
+ * Connects to the Horde compute endpoint specified by MachineInfo and provides
+ * raw TCP send/receive. ASIO internals are hidden behind a pimpl to keep the
+ * header clean.
+ */
+class TcpComputeTransport final : public ComputeTransport
+{
+public:
+	/** Connects synchronously in the constructor; check IsValid() afterwards. */
+	explicit TcpComputeTransport(const MachineInfo& Info);
+	~TcpComputeTransport() override;
+
+	bool   IsValid() const override;
+	size_t Send(const void* Data, size_t Size) override;
+	size_t Recv(void* Data, size_t Size) override;
+	void   MarkComplete() override;
+	void   Close() override;
+
+private:
+	LoggerRef Log() { return m_Log; }
+
+	struct Impl;                     // ASIO io_context + socket (see .cpp)
+	std::unique_ptr<Impl> m_Impl;
+	LoggerRef             m_Log;
+	bool                  m_IsClosed = false;  // set once by Close()
+	bool                  m_HasErrors = false; // set on connect/send failure
+};
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp
new file mode 100644
index 000000000..986dd3705
--- /dev/null
+++ b/src/zenhorde/hordetransportaes.cpp
@@ -0,0 +1,425 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordetransportaes.h"
+
+#include <zencore/logging.h>
+#include <zencore/trace.h>
+
+#include <algorithm>
+#include <cstring>
+#include <random>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+# include <bcrypt.h>
+# pragma comment(lib, "Bcrypt.lib")
+#else
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <openssl/evp.h>
+# include <openssl/err.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+#endif
+
+namespace zen::horde {
+
+/** Holds the AES-256-GCM key material, per-direction nonces, and (on OpenSSL
+ * platforms) reusable cipher contexts. All methods set HasErrors on failure so
+ * the owning transport can invalidate itself. Key material is scrubbed from
+ * memory in the destructor.
+ */
+struct AesComputeTransport::CryptoContext
+{
+    uint8_t Key[KeySize] = {};
+    uint8_t EncryptNonce[NonceBytes] = {};
+    uint8_t DecryptNonce[NonceBytes] = {};
+    bool    HasErrors = false;
+
+#if !ZEN_PLATFORM_WINDOWS
+    EVP_CIPHER_CTX* EncCtx = nullptr;
+    EVP_CIPHER_CTX* DecCtx = nullptr;
+#endif
+
+    CryptoContext(const uint8_t (&InKey)[KeySize])
+    {
+        memcpy(Key, InKey, KeySize);
+
+        // The encrypt nonce is randomly initialized and then deterministically mutated
+        // per message via UpdateNonce(). The decrypt nonce is not used — it comes from
+        // the wire (each received message carries its own nonce in the header).
+        std::random_device Rd;
+        std::mt19937 Gen(Rd());
+        std::uniform_int_distribution<int> Dist(0, 255);
+        for (auto& Byte : EncryptNonce)
+        {
+            Byte = static_cast<uint8_t>(Dist(Gen));
+        }
+
+#if !ZEN_PLATFORM_WINDOWS
+        // Drain any stale OpenSSL errors
+        while (ERR_get_error() != 0)
+        {
+        }
+
+        // Create and prime the cipher contexts once; per-message calls only
+        // rebind key + nonce. Any failure here poisons the context up front
+        // instead of dereferencing a null EVP context later.
+        EncCtx = EVP_CIPHER_CTX_new();
+        DecCtx = EVP_CIPHER_CTX_new();
+        if (EncCtx == nullptr || DecCtx == nullptr ||
+            EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1 ||
+            EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+        {
+            HasErrors = true;
+        }
+#endif
+    }
+
+    ~CryptoContext()
+    {
+#if ZEN_PLATFORM_WINDOWS
+        SecureZeroMemory(Key, sizeof(Key));
+        SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
+        SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
+#else
+        OPENSSL_cleanse(Key, sizeof(Key));
+        OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
+        OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));
+
+        if (EncCtx)
+        {
+            EVP_CIPHER_CTX_free(EncCtx);
+        }
+        if (DecCtx)
+        {
+            EVP_CIPHER_CTX_free(DecCtx);
+        }
+#endif
+    }
+
+    // Horde nonce mangling scheme. The memcpy round-trip avoids strict-aliasing
+    // UB on the uint8_t nonce buffer (NonceBytes == 12 == 3 * sizeof(uint32_t)).
+    void UpdateNonce()
+    {
+        uint32_t N32[3];
+        memcpy(N32, EncryptNonce, sizeof(N32));
+        N32[0]++;
+        N32[1]--;
+        N32[2] = N32[0] ^ N32[1];
+        memcpy(EncryptNonce, N32, sizeof(N32));
+    }
+
+#if ZEN_PLATFORM_WINDOWS
+    // Open a BCrypt AES-GCM algorithm provider and derive a key handle from Key.
+    // Returns false (with both handles null and HasErrors set, all partial state
+    // freed) on failure — previously these NTSTATUS codes went unchecked and a
+    // failed setup produced null handles passed to BCryptEncrypt/Decrypt.
+    bool OpenAesGcm(BCRYPT_ALG_HANDLE& OutAlg, BCRYPT_KEY_HANDLE& OutKey)
+    {
+        OutAlg = nullptr;
+        OutKey = nullptr;
+
+        NTSTATUS Status = BCryptOpenAlgorithmProvider(&OutAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
+        if (BCRYPT_SUCCESS(Status))
+        {
+            Status = BCryptSetProperty(OutAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
+        }
+        if (BCRYPT_SUCCESS(Status))
+        {
+            Status = BCryptGenerateSymmetricKey(OutAlg, &OutKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
+        }
+
+        if (!BCRYPT_SUCCESS(Status))
+        {
+            if (OutKey)
+            {
+                BCryptDestroyKey(OutKey);
+            }
+            if (OutAlg)
+            {
+                BCryptCloseAlgorithmProvider(OutAlg, 0);
+            }
+            OutAlg = nullptr;
+            OutKey = nullptr;
+            HasErrors = true;
+            return false;
+        }
+        return true;
+    }
+#endif
+
+    // Returns total encrypted message size, or 0 on failure
+    // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)]
+    int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
+    {
+        UpdateNonce();
+
+        // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than
+        // caching but has some overhead. For our use case (relatively large, infrequent messages)
+        // this is acceptable.
+#if ZEN_PLATFORM_WINDOWS
+        BCRYPT_ALG_HANDLE hAlg = nullptr;
+        BCRYPT_KEY_HANDLE hKey = nullptr;
+        if (!OpenAesGcm(hAlg, hKey))
+        {
+            return 0;
+        }
+
+        BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+        BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+        AuthInfo.pbNonce = EncryptNonce;
+        AuthInfo.cbNonce = NonceBytes;
+        uint8_t Tag[TagBytes] = {};
+        AuthInfo.pbTag = Tag;
+        AuthInfo.cbTag = TagBytes;
+
+        ULONG CipherLen = 0;
+        NTSTATUS Status =
+            BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0);
+
+        if (!BCRYPT_SUCCESS(Status))
+        {
+            HasErrors = true;
+            BCryptDestroyKey(hKey);
+            BCryptCloseAlgorithmProvider(hAlg, 0);
+            return 0;
+        }
+
+        // Write header: length + nonce
+        memcpy(Out, &InLength, 4);
+        memcpy(Out + 4, EncryptNonce, NonceBytes);
+        // Write tag after ciphertext
+        memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);
+
+        BCryptDestroyKey(hKey);
+        BCryptCloseAlgorithmProvider(hAlg, 0);
+
+        return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
+#else
+        if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
+        {
+            HasErrors = true;
+            return 0;
+        }
+
+        int32_t Offset = 0;
+        // Write length
+        memcpy(Out + Offset, &InLength, 4);
+        Offset += 4;
+        // Write nonce
+        memcpy(Out + Offset, EncryptNonce, NonceBytes);
+        Offset += NonceBytes;
+
+        // Encrypt
+        int OutLen = 0;
+        if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
+        {
+            HasErrors = true;
+            return 0;
+        }
+        Offset += OutLen;
+
+        // Finalize
+        int FinalLen = 0;
+        if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
+        {
+            HasErrors = true;
+            return 0;
+        }
+        Offset += FinalLen;
+
+        // Get tag
+        if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1)
+        {
+            HasErrors = true;
+            return 0;
+        }
+        Offset += TagBytes;
+
+        return Offset;
+#endif
+    }
+
+    // Decrypt a message. Returns decrypted data length, or 0 on failure.
+    // Input must be [ciphertext][tag], with nonce provided separately.
+    int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength)
+    {
+#if ZEN_PLATFORM_WINDOWS
+        BCRYPT_ALG_HANDLE hAlg = nullptr;
+        BCRYPT_KEY_HANDLE hKey = nullptr;
+        if (!OpenAesGcm(hAlg, hKey))
+        {
+            return 0;
+        }
+
+        BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+        BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+        AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce);
+        AuthInfo.cbNonce = NonceBytes;
+        AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength);
+        AuthInfo.cbTag = TagBytes;
+
+        ULONG PlainLen = 0;
+        NTSTATUS Status = BCryptDecrypt(hKey,
+                                        (PUCHAR)CipherAndTag,
+                                        (ULONG)DataLength,
+                                        &AuthInfo,
+                                        nullptr,
+                                        0,
+                                        (PUCHAR)Out,
+                                        (ULONG)DataLength,
+                                        &PlainLen,
+                                        0);
+
+        BCryptDestroyKey(hKey);
+        BCryptCloseAlgorithmProvider(hAlg, 0);
+
+        if (!BCRYPT_SUCCESS(Status))
+        {
+            // Includes authentication-tag mismatch (tampered/corrupt message)
+            HasErrors = true;
+            return 0;
+        }
+
+        return static_cast<int32_t>(PlainLen);
+#else
+        if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1)
+        {
+            HasErrors = true;
+            return 0;
+        }
+
+        int OutLen = 0;
+        if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1)
+        {
+            HasErrors = true;
+            return 0;
+        }
+
+        // Set the tag for verification
+        if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1)
+        {
+            HasErrors = true;
+            return 0;
+        }
+
+        int FinalLen = 0;
+        if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1)
+        {
+            HasErrors = true;
+            return 0;
+        }
+
+        return OutLen + FinalLen;
+#endif
+    }
+};
+
+// Takes ownership of InnerTransport; all traffic is wrapped in AES-256-GCM
+// messages keyed by the 32-byte Key obtained from the Horde handshake.
+AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport)
+: m_Crypto(std::make_unique<CryptoContext>(Key))
+, m_Inner(std::move(InnerTransport))
+{
+}
+
+// Close the inner transport on destruction (Close is idempotent).
+AesComputeTransport::~AesComputeTransport()
+{
+    Close();
+}
+
+// Valid only while both the crypto context and the wrapped transport are
+// healthy and Close() has not been called.
+bool
+AesComputeTransport::IsValid() const
+{
+    if (m_IsClosed)
+    {
+        return false;
+    }
+    if (!m_Crypto || m_Crypto->HasErrors)
+    {
+        return false;
+    }
+    return m_Inner && m_Inner->IsValid();
+}
+
+// Encrypt the payload into a single framed message and push it through the
+// inner transport. Returns Size on success, 0 on failure.
+size_t
+AesComputeTransport::Send(const void* Data, size_t Size)
+{
+    ZEN_TRACE_CPU("AesComputeTransport::Send");
+
+    if (!IsValid())
+    {
+        return 0;
+    }
+
+    std::lock_guard<std::mutex> Lock(m_Lock);
+
+    // Frame layout: [length(4B)][nonce(12B)][ciphertext(Size)][tag(16B)]
+    const size_t FramedSize = 4 + NonceBytes + Size + TagBytes;
+    if (m_EncryptBuffer.size() < FramedSize)
+    {
+        m_EncryptBuffer.resize(FramedSize);
+    }
+
+    const int32_t WireLength = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, static_cast<int32_t>(Size));
+    if (WireLength == 0 || !m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(WireLength)))
+    {
+        return 0;
+    }
+
+    return Size;
+}
+
+// Receive and decrypt one message, returning up to Size bytes of plaintext.
+//
+// AES-GCM decrypts entire messages at once, but the caller may request fewer bytes
+// than the decrypted message contains. Excess bytes are buffered in m_RemainingData
+// and returned on subsequent Recv calls without another decryption round-trip.
+// When the caller's buffer can hold the whole message we decrypt straight into it,
+// avoiding a temporary allocation and copy.
+size_t
+AesComputeTransport::Recv(void* Data, size_t Size)
+{
+    if (!IsValid())
+    {
+        return 0;
+    }
+
+    ZEN_TRACE_CPU("AesComputeTransport::Recv");
+
+    std::lock_guard<std::mutex> Lock(m_Lock);
+
+    // Serve previously decrypted, not-yet-consumed bytes first
+    if (!m_RemainingData.empty())
+    {
+        const size_t Available = m_RemainingData.size() - m_RemainingOffset;
+        const size_t ToCopy    = std::min(Available, Size);
+
+        memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy);
+        m_RemainingOffset += ToCopy;
+
+        if (m_RemainingOffset >= m_RemainingData.size())
+        {
+            m_RemainingData.clear();
+            m_RemainingOffset = 0;
+        }
+
+        return ToCopy;
+    }
+
+    // Receive packet header: [length(4B)][nonce(12B)]
+    struct PacketHeader
+    {
+        int32_t DataLength = 0;
+        uint8_t Nonce[NonceBytes] = {};
+    } Header;
+
+    if (!m_Inner->RecvMessage(&Header, sizeof(Header)))
+    {
+        return 0;
+    }
+
+    // Validate DataLength to prevent OOM from malicious/corrupt peers
+    static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB
+
+    if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength)
+    {
+        ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength);
+        return 0;
+    }
+
+    // Receive ciphertext + tag
+    const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes;
+
+    if (m_EncryptBuffer.size() < MessageLength)
+    {
+        m_EncryptBuffer.resize(MessageLength);
+    }
+
+    if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength))
+    {
+        return 0;
+    }
+
+    const size_t PlainLength   = static_cast<size_t>(Header.DataLength);
+    const size_t BytesToReturn = std::min(PlainLength, Size);
+
+    if (BytesToReturn == PlainLength)
+    {
+        // Fast path: the caller's buffer holds the entire message — decrypt
+        // directly into it, skipping the temporary buffer and extra memcpy
+        if (m_Crypto->DecryptMessage(Data, Header.Nonce, m_EncryptBuffer.data(), Header.DataLength) == 0)
+        {
+            return 0;
+        }
+        return BytesToReturn;
+    }
+
+    // Slow path: decrypt into a temporary and stash the surplus for later calls
+    std::vector<uint8_t> DecryptedBuf(PlainLength);
+    if (m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength) == 0)
+    {
+        return 0;
+    }
+
+    memcpy(Data, DecryptedBuf.data(), BytesToReturn);
+
+    m_RemainingOffset = 0;
+    m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.end());
+
+    return BytesToReturn;
+}
+
+// Forward the completion marker to the wrapped transport; no AES framing is
+// produced for it.
+void
+AesComputeTransport::MarkComplete()
+{
+    if (IsValid())
+    {
+        m_Inner->MarkComplete();
+    }
+}
+
+// Close the wrapped transport exactly once; subsequent calls are no-ops.
+void
+AesComputeTransport::Close()
+{
+    if (m_IsClosed)
+    {
+        return;
+    }
+    m_IsClosed = true;
+
+    if (m_Inner && m_Inner->IsValid())
+    {
+        m_Inner->Close();
+    }
+}
+
+} // namespace zen::horde
diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h
new file mode 100644
index 000000000..efcad9835
--- /dev/null
+++ b/src/zenhorde/hordetransportaes.h
@@ -0,0 +1,52 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "hordetransport.h"
+
+#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+namespace zen::horde {
+
+/** AES-256-GCM encrypted transport wrapper.
+ *
+ * Wraps an inner ComputeTransport, encrypting all outgoing data and decrypting
+ * all incoming data using AES-256-GCM. The nonce is mutated per message using
+ * the Horde nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1].
+ *
+ * Wire format per encrypted message:
+ *   [plaintext length (4B little-endian)][nonce (12B)][ciphertext][GCM tag (16B)]
+ *
+ * Uses BCrypt on Windows and OpenSSL EVP on Linux/macOS (selected at compile time).
+ */
+class AesComputeTransport final : public ComputeTransport
+{
+public:
+	/** Takes ownership of InnerTransport; Key is the 32-byte AES key from the handshake. */
+	AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport);
+	~AesComputeTransport() override;
+
+	bool   IsValid() const override;
+	size_t Send(const void* Data, size_t Size) override;
+	size_t Recv(void* Data, size_t Size) override;
+	void   MarkComplete() override;
+	void   Close() override;
+
+private:
+	static constexpr size_t NonceBytes = 12; ///< AES-GCM nonce size
+	static constexpr size_t TagBytes   = 16; ///< AES-GCM authentication tag size
+
+	struct CryptoContext; // key material + platform cipher state (see .cpp)
+
+	std::unique_ptr<CryptoContext>    m_Crypto;
+	std::unique_ptr<ComputeTransport> m_Inner;
+	std::vector<uint8_t>              m_EncryptBuffer; ///< Reused scratch for outgoing frames and incoming ciphertext
+	std::vector<uint8_t>              m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv
+	size_t                            m_RemainingOffset = 0;
+	std::mutex                        m_Lock;          ///< Guards buffers and crypto state across Send/Recv
+	bool                              m_IsClosed = false;
+};
+
+} // namespace zen::horde
diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h
new file mode 100644
index 000000000..201d68b83
--- /dev/null
+++ b/src/zenhorde/include/zenhorde/hordeclient.h
@@ -0,0 +1,116 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhorde/hordeconfig.h>
+
+#include <zencore/logbase.h>
+
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace zen {
+class HttpClient;
+}
+
+namespace zen::horde {
+
+// Sizes fixed by the Horde compute protocol: the 64-byte handshake nonce and
+// the 32-byte AES key carried in MachineInfo.
+static constexpr size_t NonceSize = 64;
+static constexpr size_t KeySize	  = 32;
+
+/** Port mapping information returned by Horde for a provisioned machine. */
+struct PortInfo
+{
+	uint16_t Port = 0;      ///< Externally reachable port (used when connecting through a relay)
+	uint16_t AgentPort = 0; ///< presumably the corresponding port on the agent itself — not read in this module; confirm against the API response
+};
+
+/** Describes a provisioned compute machine returned by the Horde API.
+ *
+ * Contains the network address, encryption credentials, and capabilities
+ * needed to establish a compute transport connection to the machine.
+ */
+struct MachineInfo
+{
+	std::string	   Ip;                                ///< Direct IP of the agent
+	ConnectionMode Mode = ConnectionMode::Direct;
+	std::string	   ConnectionAddress; ///< Relay/tunnel address (used when Mode != Direct)
+	uint16_t	   Port = 0;                          ///< Direct connection port
+	uint16_t	   LogicalCores = 0;                  ///< Logical core count reported for the machine
+	Encryption	   EncryptionMode = Encryption::None;
+	uint8_t		   Nonce[NonceSize] = {}; ///< 64-byte nonce sent during TCP handshake
+	uint8_t		   Key[KeySize] = {};	  ///< 32-byte AES key (when EncryptionMode == AES)
+	bool		   IsWindows = false;                 ///< True when the agent runs Windows (affects Wine usage)
+	std::string	   LeaseId;                           ///< Horde lease identifier for this machine
+
+	std::map<std::string, PortInfo> Ports;            ///< Named port mappings (e.g. "_horde_compute") for relay mode
+
+	/** Return the address to connect to, accounting for connection mode. */
+	const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; }
+
+	/** Return the port to connect to, accounting for connection mode and port mapping. */
+	uint16_t GetConnectionPort() const
+	{
+		if (Mode == ConnectionMode::Relay)
+		{
+			auto It = Ports.find("_horde_compute");
+			if (It != Ports.end())
+			{
+				return It->second.Port;
+			}
+		}
+		return Port;
+	}
+
+	// NOTE(review): 0xFFFF is treated as the "no port" sentinel here, but Port
+	// default-initializes to 0, which would pass this check — confirm the
+	// sentinel against the API response parser.
+	bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; }
+};
+
+/** Result of cluster auto-resolution via the Horde API. */
+struct ClusterInfo
+{
+	std::string ClusterId = "default"; ///< Resolved cluster identifier; defaults to "default"
+};
+
+/** HTTP client for the Horde compute REST API.
+ *
+ * Handles cluster resolution and machine provisioning requests. Each call
+ * is synchronous and returns success/failure. Thread safety: individual
+ * methods are not thread-safe; callers must synchronize access.
+ */
+class HordeClient
+{
+public:
+	explicit HordeClient(const HordeConfig& Config);
+	~HordeClient();
+
+	// Non-copyable: owns a unique HTTP client connection
+	HordeClient(const HordeClient&) = delete;
+	HordeClient& operator=(const HordeClient&) = delete;
+
+	/** Initialize the underlying HTTP client. Must be called before other methods. */
+	bool Initialize();
+
+	/** Build the JSON request body for cluster resolution and machine requests.
+	 * Encodes pool, condition, connection mode, encryption, and port requirements. */
+	std::string BuildRequestBody() const;
+
+	/** Resolve the best cluster for the given request via POST /api/v2/compute/_cluster. */
+	bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster);
+
+	/** Request a compute machine from the given cluster via POST /api/v2/compute/{clusterId}.
+	 * On success, populates OutMachine with connection details and credentials. */
+	bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine);
+
+	LoggerRef Log() { return m_Log; }
+
+private:
+	/** Decode a hex string into Out (e.g. the nonce/key fields of a machine response). */
+	bool ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize);
+
+	HordeConfig                      m_Config; ///< Copied at construction
+	std::unique_ptr<zen::HttpClient> m_Http;   ///< Created by Initialize()
+	LoggerRef                        m_Log;
+};
+
+} // namespace zen::horde
diff --git a/src/zenhorde/include/zenhorde/hordeconfig.h b/src/zenhorde/include/zenhorde/hordeconfig.h
new file mode 100644
index 000000000..dd70f9832
--- /dev/null
+++ b/src/zenhorde/include/zenhorde/hordeconfig.h
@@ -0,0 +1,62 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhorde/zenhorde.h>
+
+#include <string>
+
+namespace zen::horde {
+
+/** Transport connection mode for Horde compute agents. */
+enum class ConnectionMode
+{
+	Direct, ///< Connect directly to the agent IP
+	Tunnel, ///< Connect through a Horde tunnel relay
+	Relay,	///< Connect through a Horde relay with port mapping (see MachineInfo::Ports)
+};
+
+/** Transport encryption mode for Horde compute channels. */
+enum class Encryption
+{
+	None, ///< No encryption
+	AES,  ///< AES-256-GCM encryption (required for Relay mode)
+};
+
+/** Configuration for connecting to an Epic Horde compute cluster.
+ *
+ * Specifies the Horde server URL, authentication token, pool selection,
+ * connection mode, and resource limits. Used by HordeClient and HordeProvisioner.
+ */
+struct HordeConfig
+{
+	static constexpr const char* ClusterDefault = "default"; ///< Fallback cluster ID
+	static constexpr const char* ClusterAuto	= "_auto";   ///< Sentinel: resolve cluster via the API
+
+	bool		Enabled = false;	  ///< Whether Horde provisioning is active
+	std::string ServerUrl;			  ///< Horde server base URL (e.g. "https://horde.example.com")
+	std::string AuthToken;			  ///< Authentication token for the Horde API
+	std::string Pool;				  ///< Pool name to request machines from
+	std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve
+	std::string Condition;			  ///< Agent filter expression for machine selection
+	std::string HostAddress;		  ///< Address that provisioned agents use to connect back to us
+	std::string BinariesPath;		  ///< Path to directory containing zenserver binary for remote upload
+	uint16_t	ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication
+
+	int	 MaxCores = 2048;  ///< Upper bound on total provisioned cores
+	bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents
+	ConnectionMode Mode = ConnectionMode::Direct;
+	Encryption	   EncryptionMode = Encryption::None;
+
+	/** Validate the configuration. Returns false if the configuration is invalid
+	 * (e.g. Relay mode without AES encryption). */
+	bool Validate() const;
+};
+
+/** Convert an enum value to its string name (for logs and config output). */
+const char* ToString(ConnectionMode Mode);
+const char* ToString(Encryption Enc);
+
+/** Parse an enum value from its string name; presumably returns false for
+ * unrecognized input — confirm against the implementation. */
+bool FromString(ConnectionMode& OutMode, std::string_view Str);
+bool FromString(Encryption& OutEnc, std::string_view Str);
+
+} // namespace zen::horde
diff --git a/src/zenhorde/include/zenhorde/hordeprovisioner.h b/src/zenhorde/include/zenhorde/hordeprovisioner.h
new file mode 100644
index 000000000..4e2e63bbd
--- /dev/null
+++ b/src/zenhorde/include/zenhorde/hordeprovisioner.h
@@ -0,0 +1,110 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhorde/hordeconfig.h>
+
+#include <zencore/logbase.h>
+
+#include <atomic>
+#include <cstdint>
+#include <filesystem>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <vector>
+
+namespace zen::horde {
+
+class HordeClient;
+
+/** Snapshot of the current provisioning state, returned by HordeProvisioner::GetStats().
+ *
+ * NOTE(review): the counters are stored as independent atomics on the
+ * provisioner, so a snapshot may be momentarily inconsistent across fields.
+ */
+struct ProvisioningStats
+{
+	uint32_t TargetCoreCount = 0;	 ///< Requested number of cores (clamped to MaxCores)
+	uint32_t EstimatedCoreCount = 0; ///< Cores expected once pending requests complete
+	uint32_t ActiveCoreCount = 0;	 ///< Cores on machines that are currently running zenserver
+	uint32_t AgentsActive = 0;		 ///< Number of agents with a running remote process
+	uint32_t AgentsRequesting = 0;	 ///< Number of agents currently requesting a machine from Horde
+};
+
+/** Multi-agent lifecycle manager for Horde worker provisioning.
+ *
+ * Provisions remote compute workers by requesting machines from the Horde API,
+ * connecting via the Horde compute transport protocol, uploading the zenserver
+ * binary, and executing it remotely. Each provisioned machine runs zenserver
+ * in compute mode, which announces itself back to the orchestrator.
+ *
+ * Spawns one thread per agent. Each thread handles the full lifecycle:
+ * HTTP request -> TCP connect -> nonce handshake -> optional AES encryption ->
+ * channel setup -> binary upload -> remote execution -> poll until exit.
+ *
+ * Thread safety: SetTargetCoreCount and GetStats may be called from any thread.
+ */
+class HordeProvisioner
+{
+public:
+	/** Construct a provisioner.
+	 * @param Config Horde connection and pool configuration.
+	 * @param BinariesPath Directory containing the zenserver binary to upload.
+	 * @param WorkingDir Local directory for bundle staging and working files.
+	 * @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. */
+	HordeProvisioner(const HordeConfig&			  Config,
+					 const std::filesystem::path& BinariesPath,
+					 const std::filesystem::path& WorkingDir,
+					 std::string_view			  OrchestratorEndpoint);
+
+	/** Signals all agent threads to exit and joins them. */
+	~HordeProvisioner();
+
+	// Non-copyable: owns agent threads and network state
+	HordeProvisioner(const HordeProvisioner&) = delete;
+	HordeProvisioner& operator=(const HordeProvisioner&) = delete;
+
+	/** Set the target number of cores to provision.
+	 * Clamped to HordeConfig::MaxCores. Spawns new agent threads if the
+	 * estimated core count is below the target. Also joins any finished
+	 * agent threads. */
+	void SetTargetCoreCount(uint32_t Count);
+
+	/** Return a snapshot of the current provisioning counters. */
+	ProvisioningStats GetStats() const;
+
+	uint32_t GetActiveCoreCount() const { return m_ActiveCoreCount.load(); }
+	uint32_t GetAgentCount() const;
+
+private:
+	LoggerRef Log() { return m_Log; }
+
+	struct AgentWrapper; // per-agent thread + exit flag (defined in the .cpp)
+
+	void RequestAgent();
+	void ThreadAgent(AgentWrapper& Wrapper); ///< Per-agent thread entry: full provision/upload/exec/poll lifecycle
+
+	HordeConfig			  m_Config;
+	std::filesystem::path m_BinariesPath;
+	std::filesystem::path m_WorkingDir;
+	std::string			  m_OrchestratorEndpoint;
+
+	std::unique_ptr<HordeClient> m_HordeClient; ///< Created lazily by the first agent thread
+
+	std::mutex											  m_BundleLock; ///< Guards bundle creation and HordeClient init
+	std::vector<std::pair<std::string, std::filesystem::path>> m_Bundles; ///< (locator, bundleDir) pairs
+	bool												  m_BundlesCreated = false;
+
+	mutable std::mutex						   m_AgentsLock;
+	std::vector<std::unique_ptr<AgentWrapper>> m_Agents;
+
+	std::atomic<uint64_t> m_LastRequestFailTime{0}; ///< steady_clock ticks of last failed machine request; 0 = none (drives backoff)
+	std::atomic<uint32_t> m_TargetCoreCount{0};
+	std::atomic<uint32_t> m_EstimatedCoreCount{0};
+	std::atomic<uint32_t> m_ActiveCoreCount{0};
+	std::atomic<uint32_t> m_AgentsActive{0};
+	std::atomic<uint32_t> m_AgentsRequesting{0};
+	std::atomic<bool>	  m_AskForAgents{true}; ///< Cleared when bundle/client setup fails, disabling further requests
+
+	LoggerRef m_Log;
+
+	// Assumed core count per agent before the real machine size is known
+	static constexpr uint32_t EstimatedCoresPerAgent = 32;
+};
+
+} // namespace zen::horde
diff --git a/src/zenhorde/include/zenhorde/zenhorde.h b/src/zenhorde/include/zenhorde/zenhorde.h
new file mode 100644
index 000000000..35147ff75
--- /dev/null
+++ b/src/zenhorde/include/zenhorde/zenhorde.h
@@ -0,0 +1,9 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#if !defined(ZEN_WITH_HORDE)
+# define ZEN_WITH_HORDE 1
+#endif
diff --git a/src/zenhorde/xmake.lua b/src/zenhorde/xmake.lua
new file mode 100644
index 000000000..48d028e86
--- /dev/null
+++ b/src/zenhorde/xmake.lua
@@ -0,0 +1,22 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target('zenhorde')
+ set_kind("static")
+ set_group("libs")
+ add_headerfiles("**.h")
+ add_files("**.cpp")
+ add_includedirs("include", {public=true})
+ add_deps("zencore", "zenhttp", "zencompute", "zenutil")
+ add_packages("asio", "json11")
+
+ if is_plat("windows") then
+ add_syslinks("Ws2_32", "Bcrypt")
+ end
+
+ if is_plat("linux") or is_plat("macosx") then
+ add_packages("openssl")
+ end
+
+ if is_os("macosx") then
+ add_cxxflags("-Wno-deprecated-declarations")
+ end
diff --git a/src/zenhttp-test/zenhttp-test.cpp b/src/zenhttp-test/zenhttp-test.cpp
index c18759beb..b4b406ac8 100644
--- a/src/zenhttp-test/zenhttp-test.cpp
+++ b/src/zenhttp-test/zenhttp-test.cpp
@@ -1,44 +1,15 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zencore/filesystem.h>
-#include <zencore/logging.h>
-#include <zencore/memory/newdelete.h>
-#include <zencore/trace.h>
+#include <zencore/testing.h>
#include <zenhttp/zenhttp.h>
-#if ZEN_WITH_TESTS
-# define ZEN_TEST_WITH_RUNNER 1
-# include <zencore/testing.h>
-# include <zencore/process.h>
-#endif
+#include <zencore/memory/newdelete.h>
int
main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
{
-#if ZEN_PLATFORM_WINDOWS
- setlocale(LC_ALL, "en_us.UTF8");
-#endif // ZEN_PLATFORM_WINDOWS
-
#if ZEN_WITH_TESTS
- zen::zenhttp_forcelinktests();
-
-# if ZEN_PLATFORM_LINUX
- zen::IgnoreChildSignals();
-# endif
-
-# if ZEN_WITH_TRACE
- zen::TraceInit("zenhttp-test");
- zen::TraceOptions TraceCommandlineOptions;
- if (GetTraceOptionsFromCommandline(TraceCommandlineOptions))
- {
- TraceConfigure(TraceCommandlineOptions);
- }
-# endif // ZEN_WITH_TRACE
-
- zen::logging::InitializeLogging();
- zen::MaximizeOpenFileCount();
-
- return ZEN_RUN_TESTS(argc, argv);
+ return zen::testing::RunTestMain(argc, argv, "zenhttp-test", zen::zenhttp_forcelinktests);
#else
return 0;
#endif
diff --git a/src/zenhttp/auth/oidc.cpp b/src/zenhttp/auth/oidc.cpp
index 38e7586ad..23bbc17e8 100644
--- a/src/zenhttp/auth/oidc.cpp
+++ b/src/zenhttp/auth/oidc.cpp
@@ -32,6 +32,25 @@ namespace details {
using namespace std::literals;
+static std::string
+FormUrlEncode(std::string_view Input)
+{
+ std::string Result;
+ Result.reserve(Input.size());
+ for (char C : Input)
+ {
+ if ((C >= 'A' && C <= 'Z') || (C >= 'a' && C <= 'z') || (C >= '0' && C <= '9') || C == '-' || C == '_' || C == '.' || C == '~')
+ {
+ Result.push_back(C);
+ }
+ else
+ {
+ Result.append(fmt::format("%{:02X}", static_cast<uint8_t>(C)));
+ }
+ }
+ return Result;
+}
+
OidcClient::OidcClient(const OidcClient::Options& Options)
{
m_BaseUrl = std::string(Options.BaseUrl);
@@ -67,6 +86,8 @@ OidcClient::Initialize()
.TokenEndpoint = Json["token_endpoint"].string_value(),
.UserInfoEndpoint = Json["userinfo_endpoint"].string_value(),
.RegistrationEndpoint = Json["registration_endpoint"].string_value(),
+ .EndSessionEndpoint = Json["end_session_endpoint"].string_value(),
+ .DeviceAuthorizationEndpoint = Json["device_authorization_endpoint"].string_value(),
.JwksUri = Json["jwks_uri"].string_value(),
.SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]),
.SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]),
@@ -81,7 +102,8 @@ OidcClient::Initialize()
OidcClient::RefreshTokenResult
OidcClient::RefreshToken(std::string_view RefreshToken)
{
- const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId);
+ const std::string Body =
+ fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", FormUrlEncode(RefreshToken), FormUrlEncode(m_ClientId));
HttpClient Http{m_Config.TokenEndpoint};
diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp
index 47425e014..6f4c67dd0 100644
--- a/src/zenhttp/clients/httpclientcommon.cpp
+++ b/src/zenhttp/clients/httpclientcommon.cpp
@@ -142,7 +142,10 @@ namespace detail {
DataSize -= CopySize;
if (m_CacheBufferOffset == CacheBufferSize)
{
- AppendData(m_CacheBuffer, CacheBufferSize);
+ if (std::error_code Ec = AppendData(m_CacheBuffer, CacheBufferSize))
+ {
+ return Ec;
+ }
if (DataSize > 0)
{
ZEN_ASSERT(DataSize < CacheBufferSize);
@@ -382,6 +385,177 @@ namespace detail {
return Result;
}
+ MultipartBoundaryParser::MultipartBoundaryParser() : BoundaryEndMatcher("--"), HeaderEndMatcher("\r\n\r\n") {}
+
+ bool MultipartBoundaryParser::Init(const std::string_view ContentTypeHeaderValue)
+ {
+ std::string LowerCaseValue = ToLower(ContentTypeHeaderValue);
+ if (LowerCaseValue.starts_with("multipart/byteranges"))
+ {
+ size_t BoundaryPos = LowerCaseValue.find("boundary=");
+ if (BoundaryPos != std::string::npos)
+ {
+ // Yes, we do a substring of the non-lowercase value string as we want the exact boundary string
+ std::string_view BoundaryName = std::string_view(ContentTypeHeaderValue).substr(BoundaryPos + 9);
+ size_t BoundaryEnd = std::string::npos;
+ while (!BoundaryName.empty() && BoundaryName[0] == ' ')
+ {
+ BoundaryName = BoundaryName.substr(1);
+ }
+ if (!BoundaryName.empty())
+ {
+ if (BoundaryName.size() > 2 && BoundaryName.front() == '"' && BoundaryName.back() == '"')
+ {
+ BoundaryEnd = BoundaryName.find('"', 1);
+ if (BoundaryEnd != std::string::npos)
+ {
+ BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(1, BoundaryEnd - 1)));
+ return true;
+ }
+ }
+ else
+ {
+ BoundaryEnd = BoundaryName.find_first_of(" \r\n");
+ BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(0, BoundaryEnd)));
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ void MultipartBoundaryParser::ParseInput(std::string_view data)
+ {
+ const char* InputPtr = data.data();
+ size_t InputLength = data.length();
+ size_t ScanPos = 0;
+ while (ScanPos < InputLength)
+ {
+ const char ScanChar = InputPtr[ScanPos];
+ if (BoundaryBeginMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete)
+ {
+ if (PayloadOffset + ScanPos < (BoundaryBeginMatcher.GetMatchEndOffset() + BoundaryEndMatcher.GetMatchString().length()))
+ {
+ BoundaryEndMatcher.Match(PayloadOffset + ScanPos, ScanChar);
+ if (BoundaryEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete)
+ {
+ BoundaryBeginMatcher.Reset();
+ HeaderEndMatcher.Reset();
+ BoundaryEndMatcher.Reset();
+ BoundaryHeader.Reset();
+ break;
+ }
+ }
+
+ BoundaryHeader.Append(ScanChar);
+
+ HeaderEndMatcher.Match(PayloadOffset + ScanPos, ScanChar);
+
+ if (HeaderEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete)
+ {
+ const uint64_t HeaderStartOffset = BoundaryBeginMatcher.GetMatchEndOffset();
+ const uint64_t HeaderEndOffset = HeaderEndMatcher.GetMatchStartOffset();
+ const uint64_t HeaderLength = HeaderEndOffset - HeaderStartOffset;
+ std::string_view HeaderText(BoundaryHeader.ToView().substr(0, HeaderLength));
+
+ uint64_t OffsetInPayload = PayloadOffset + ScanPos + 1;
+
+ uint64_t RangeOffset = 0;
+ uint64_t RangeLength = 0;
+ HttpContentType ContentType = HttpContentType::kBinary;
+
+ ForEachStrTok(HeaderText, "\r\n", [&](std::string_view Line) {
+ const std::pair<std::string_view, std::string_view> KeyAndValue = GetHeaderKeyAndValue(Line);
+ const std::string_view Key = KeyAndValue.first;
+ const std::string_view Value = KeyAndValue.second;
+ if (Key == "Content-Range")
+ {
+ std::pair<uint64_t, uint64_t> ContentRange = ParseContentRange(Value);
+ if (ContentRange.second != 0)
+ {
+ RangeOffset = ContentRange.first;
+ RangeLength = ContentRange.second;
+ }
+ }
+ else if (Key == "Content-Type")
+ {
+ ContentType = ParseContentType(Value);
+ }
+
+ return true;
+ });
+
+ if (RangeLength > 0)
+ {
+ Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = OffsetInPayload,
+ .RangeOffset = RangeOffset,
+ .RangeLength = RangeLength,
+ .ContentType = ContentType});
+ }
+
+ BoundaryBeginMatcher.Reset();
+ HeaderEndMatcher.Reset();
+ BoundaryEndMatcher.Reset();
+ BoundaryHeader.Reset();
+ }
+ }
+ else
+ {
+ BoundaryBeginMatcher.Match(PayloadOffset + ScanPos, ScanChar);
+ }
+ ScanPos++;
+ }
+ PayloadOffset += InputLength;
+ }
+
+ std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString)
+ {
+ size_t DelimiterPos = HeaderString.find(':');
+ if (DelimiterPos != std::string::npos)
+ {
+ std::string_view Key = HeaderString.substr(0, DelimiterPos);
+ constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n");
+ Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters);
+ Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters);
+
+ std::string_view Value = HeaderString.substr(DelimiterPos + 1);
+ Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters);
+ Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters);
+ return std::make_pair(Key, Value);
+ }
+ return std::make_pair(HeaderString, std::string_view{});
+ }
+
+ std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value)
+ {
+ if (Value.starts_with("bytes "))
+ {
+ size_t RangeSplitPos = Value.find('-', 6);
+ if (RangeSplitPos != std::string::npos)
+ {
+ size_t RangeEndLength = Value.find('/', RangeSplitPos + 1);
+ if (RangeEndLength == std::string::npos)
+ {
+ RangeEndLength = Value.length() - (RangeSplitPos + 1);
+ }
+ else
+ {
+ RangeEndLength = RangeEndLength - (RangeSplitPos + 1);
+ }
+ std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(Value.substr(6, RangeSplitPos - 6));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(Value.substr(RangeSplitPos + 1, RangeEndLength));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ uint64_t RangeOffset = RequestedRangeStart.value();
+ uint64_t RangeLength = RequestedRangeEnd.value() - RangeOffset + 1;
+ return std::make_pair(RangeOffset, RangeLength);
+ }
+ }
+ }
+ return {0, 0};
+ }
+
} // namespace detail
} // namespace zen
@@ -423,6 +597,8 @@ namespace testutil {
} // namespace testutil
+TEST_SUITE_BEGIN("http.httpclientcommon");
+
TEST_CASE("BufferedReadFileStream")
{
ScopedTemporaryDirectory TmpDir;
@@ -470,5 +646,150 @@ TEST_CASE("CompositeBufferReadStream")
CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data));
}
+TEST_CASE("MultipartBoundaryParser")
+{
+ uint64_t Range1Offset = 2638;
+ uint64_t Range1Length = (5111437 - Range1Offset) + 1;
+
+ uint64_t Range2Offset = 5118199;
+ uint64_t Range2Length = (9147741 - Range2Offset) + 1;
+
+ std::string_view ContentTypeHeaderValue1 = "multipart/byteranges; boundary=00000000000000019229";
+ std::string_view ContentTypeHeaderValue2 = "multipart/byteranges; boundary=\"00000000000000019229\"";
+
+ {
+ std::string_view Example1 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/44369878\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita"
+ "\r\n--00000000000000019229--";
+
+ detail::MultipartBoundaryParser ParserExample1;
+ ParserExample1.Init(ContentTypeHeaderValue1);
+
+ const size_t InputWindow = 7;
+ for (size_t Offset = 0; Offset < Example1.length(); Offset += InputWindow)
+ {
+ ParserExample1.ParseInput(Example1.substr(Offset, Min(Example1.length() - Offset, InputWindow)));
+ }
+
+ CHECK(ParserExample1.Boundaries.size() == 2);
+
+ CHECK(ParserExample1.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample1.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample1.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample1.Boundaries[1].RangeLength == Range2Length);
+ }
+
+ {
+ std::string_view Example2 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/*\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita"
+ "\r\n--00000000000000019229--";
+
+ detail::MultipartBoundaryParser ParserExample2;
+ ParserExample2.Init(ContentTypeHeaderValue1);
+
+ const size_t InputWindow = 3;
+ for (size_t Offset = 0; Offset < Example2.length(); Offset += InputWindow)
+ {
+ std::string_view Window = Example2.substr(Offset, Min(Example2.length() - Offset, InputWindow));
+ ParserExample2.ParseInput(Window);
+ }
+
+ CHECK(ParserExample2.Boundaries.size() == 2);
+
+ CHECK(ParserExample2.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample2.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample2.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample2.Boundaries[1].RangeLength == Range2Length);
+ }
+
+ {
+ std::string_view Example3 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/*\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita";
+
+ detail::MultipartBoundaryParser ParserExample3;
+ ParserExample3.Init(ContentTypeHeaderValue2);
+
+ const size_t InputWindow = 31;
+ for (size_t Offset = 0; Offset < Example3.length(); Offset += InputWindow)
+ {
+ ParserExample3.ParseInput(Example3.substr(Offset, Min(Example3.length() - Offset, InputWindow)));
+ }
+
+ CHECK(ParserExample3.Boundaries.size() == 2);
+
+ CHECK(ParserExample3.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample3.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample3.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample3.Boundaries[1].RangeLength == Range2Length);
+ }
+
+ {
+ std::string_view Example4 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/*\r\n"
+ "Not: really\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--000000000bait0019229\r\n"
+ "\r\n--00\r\n--000000000bait001922\r\n"
+ "\r\n\r\n\r\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita"
+ "Content-Type: application/x-ue-comp\r\n"
+ "ditaditadita"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n---\r\n--00000000000000019229--";
+
+ detail::MultipartBoundaryParser ParserExample4;
+ ParserExample4.Init(ContentTypeHeaderValue1);
+
+ const size_t InputWindow = 3;
+ for (size_t Offset = 0; Offset < Example4.length(); Offset += InputWindow)
+ {
+ std::string_view Window = Example4.substr(Offset, Min(Example4.length() - Offset, InputWindow));
+ ParserExample4.ParseInput(Window);
+ }
+
+ CHECK(ParserExample4.Boundaries.size() == 2);
+
+ CHECK(ParserExample4.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample4.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample4.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample4.Boundaries[1].RangeLength == Range2Length);
+ }
+}
+
+TEST_SUITE_END();
+
} // namespace zen
#endif
diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h
index 1d0b7f9ea..5ed946541 100644
--- a/src/zenhttp/clients/httpclientcommon.h
+++ b/src/zenhttp/clients/httpclientcommon.h
@@ -3,6 +3,7 @@
#pragma once
#include <zencore/compositebuffer.h>
+#include <zencore/string.h>
#include <zencore/trace.h>
#include <zenhttp/httpclient.h>
@@ -87,7 +88,7 @@ namespace detail {
std::error_code Write(std::string_view DataString);
IoBuffer DetachToIoBuffer();
IoBuffer BorrowIoBuffer();
- inline uint64_t GetSize() const { return m_WriteOffset; }
+ inline uint64_t GetSize() const { return m_WriteOffset + m_CacheBufferOffset; }
void ResetWritePos(uint64_t WriteOffset);
private:
@@ -143,6 +144,118 @@ namespace detail {
uint64_t m_BytesLeftInSegment;
};
+ class IncrementalStringMatcher
+ {
+ public:
+ enum class EMatchState
+ {
+ None,
+ Partial,
+ Complete
+ };
+
+ EMatchState MatchState = EMatchState::None;
+
+ IncrementalStringMatcher() {}
+
+ IncrementalStringMatcher(std::string&& InMatchString) : MatchString(std::move(InMatchString))
+ {
+ RawMatchString = MatchString.data();
+ }
+
+ void Init(std::string&& InMatchString)
+ {
+ MatchString = std::move(InMatchString);
+ RawMatchString = MatchString.data();
+ }
+
+ inline void Reset()
+ {
+ MatchLength = 0;
+ MatchStartOffset = 0;
+ MatchState = EMatchState::None;
+ }
+
+ inline uint64_t GetMatchEndOffset() const
+ {
+ if (MatchState == EMatchState::Complete)
+ {
+ return MatchStartOffset + MatchString.length();
+ }
+ return 0;
+ }
+
+ inline uint64_t GetMatchStartOffset() const
+ {
+ ZEN_ASSERT(MatchState == EMatchState::Complete);
+ return MatchStartOffset;
+ }
+
+ void Match(uint64_t Offset, char C)
+ {
+ ZEN_ASSERT_SLOW(RawMatchString != nullptr);
+
+ if (MatchState == EMatchState::Complete)
+ {
+ Reset();
+ }
+ if (C == RawMatchString[MatchLength])
+ {
+ if (MatchLength == 0)
+ {
+ MatchStartOffset = Offset;
+ }
+ MatchLength++;
+ if (MatchLength == MatchString.length())
+ {
+ MatchState = EMatchState::Complete;
+ }
+ else
+ {
+ MatchState = EMatchState::Partial;
+ }
+ }
+ else if (MatchLength != 0)
+ {
+ Reset();
+ Match(Offset, C);
+ }
+ else
+ {
+ Reset();
+ }
+ }
+ inline const std::string& GetMatchString() const { return MatchString; }
+
+ private:
+ std::string MatchString;
+ const char* RawMatchString = nullptr;
+ uint64_t MatchLength = 0;
+
+ uint64_t MatchStartOffset = 0;
+ };
+
+ class MultipartBoundaryParser
+ {
+ public:
+ std::vector<HttpClient::Response::MultipartBoundary> Boundaries;
+
+ MultipartBoundaryParser();
+ bool Init(const std::string_view ContentTypeHeaderValue);
+ void ParseInput(std::string_view data);
+
+ private:
+ IncrementalStringMatcher BoundaryBeginMatcher;
+ IncrementalStringMatcher BoundaryEndMatcher;
+ IncrementalStringMatcher HeaderEndMatcher;
+
+ ExtendableStringBuilder<64> BoundaryHeader;
+ uint64_t PayloadOffset = 0;
+ };
+
+ std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString);
+ std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value);
+
} // namespace detail
} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp
index 5d92b3b6b..14e40b02a 100644
--- a/src/zenhttp/clients/httpclientcpr.cpp
+++ b/src/zenhttp/clients/httpclientcpr.cpp
@@ -12,6 +12,7 @@
#include <zencore/session.h>
#include <zencore/stream.h>
#include <zenhttp/packageformat.h>
+#include <algorithm>
namespace zen {
@@ -23,6 +24,21 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti
static std::atomic<uint32_t> HttpClientRequestIdCounter{0};
+bool
+HttpClient::ErrorContext::IsConnectionError() const
+{
+ switch (static_cast<cpr::ErrorCode>(ErrorCode))
+ {
+ case cpr::ErrorCode::CONNECTION_FAILURE:
+ case cpr::ErrorCode::OPERATION_TIMEDOUT:
+ case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
+ case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
+ return true;
+ default:
+ return false;
+ }
+}
+
// If we want to support different HTTP client implementations then we'll need to make this more abstract
HttpClientError::ResponseClass
@@ -149,6 +165,18 @@ CprHttpClient::CprHttpClient(std::string_view BaseUri,
{
}
+bool
+CprHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const
+{
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ // Quiet
+ return false;
+ }
+ const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes;
+ return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end();
+}
+
CprHttpClient::~CprHttpClient()
{
ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient");
@@ -162,10 +190,11 @@ CprHttpClient::~CprHttpClient()
}
HttpClient::Response
-CprHttpClient::ResponseWithPayload(std::string_view SessionId,
- cpr::Response&& HttpResponse,
- const HttpResponseCode WorkResponseCode,
- IoBuffer&& Payload)
+CprHttpClient::ResponseWithPayload(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
{
// This ends up doing a memcpy, would be good to get rid of it by streaming results
// into buffer directly
@@ -174,30 +203,37 @@ CprHttpClient::ResponseWithPayload(std::string_view SessionId,
if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end())
{
const HttpContentType ContentType = ParseContentType(It->second);
-
ResponseBuffer.SetContentType(ContentType);
}
- const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
-
- if (!Quiet)
+ if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
{
- if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
+ if (ShouldLogErrorCode(WorkResponseCode))
{
ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse);
}
}
+ std::sort(BoundaryPositions.begin(),
+ BoundaryPositions.end(),
+ [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) {
+ return Lhs.RangeOffset < Rhs.RangeOffset;
+ });
+
return HttpClient::Response{.StatusCode = WorkResponseCode,
.ResponsePayload = std::move(ResponseBuffer),
.Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
.UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
.DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
- .ElapsedSeconds = HttpResponse.elapsed};
+ .ElapsedSeconds = HttpResponse.elapsed,
+ .Ranges = std::move(BoundaryPositions)};
}
HttpClient::Response
-CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload)
+CprHttpClient::CommonResponse(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
{
const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code);
if (HttpResponse.error)
@@ -235,7 +271,7 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe
}
else
{
- return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload));
+ return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions));
}
}
@@ -346,8 +382,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId,
}
Sleep(100 * (Attempt + 1));
Attempt++;
- const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
- if (!Quiet)
+ if (ShouldLogErrorCode(HttpResponseCode(Result.status_code)))
{
ZEN_INFO("{} Attempt {}/{}",
CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
@@ -385,8 +420,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId,
}
Sleep(100 * (Attempt + 1));
Attempt++;
- const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
- if (!Quiet)
+ if (ShouldLogErrorCode(HttpResponseCode(Result.status_code)))
{
ZEN_INFO("{} Attempt {}/{}",
CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
@@ -621,7 +655,7 @@ CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const Ke
ResponseBuffer.SetContentType(ContentType);
}
- return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer};
+ return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = std::move(ResponseBuffer)};
}
//////////////////////////////////////////////////////////////////////////
@@ -896,236 +930,287 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF
std::string PayloadString;
std::unique_ptr<detail::TempPayloadFile> PayloadFile;
- cpr::Response Response = DoWithRetry(
- m_SessionId,
- [&]() {
- auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> {
- size_t DelimiterPos = header.find(':');
- if (DelimiterPos != std::string::npos)
- {
- std::string Key = header.substr(0, DelimiterPos);
- constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n");
- Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters);
- Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters);
-
- std::string Value = header.substr(DelimiterPos + 1);
- Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters);
- Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters);
-
- return std::make_pair(Key, Value);
- }
- return std::make_pair(header, "");
- };
-
- auto DownloadCallback = [&](std::string data, intptr_t) {
- if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
- {
- return false;
- }
- if (PayloadFile)
- {
- ZEN_ASSERT(PayloadString.empty());
- std::error_code Ec = PayloadFile->Write(data);
- if (Ec)
- {
- ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
- TempFolderPath.string(),
- Ec.message());
- return false;
- }
- }
- else
- {
- PayloadString.append(data);
- }
- return true;
- };
-
- uint64_t RequestedContentLength = (uint64_t)-1;
- if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
- {
- if (RangeIt->second.starts_with("bytes"))
- {
- size_t RangeStartPos = RangeIt->second.find('=', 5);
- if (RangeStartPos != std::string::npos)
- {
- RangeStartPos++;
- size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos);
- if (RangeSplitPos != std::string::npos)
- {
- std::optional<size_t> RequestedRangeStart =
- ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos));
- std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1));
- if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
- {
- RequestedContentLength = RequestedRangeEnd.value() - 1;
- }
- }
- }
- }
- }
-
- cpr::Response Response;
- {
- std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
- auto HeaderCallback = [&](std::string header, intptr_t) {
- std::pair<std::string, std::string> Header = GetHeader(header);
- if (Header.first == "Content-Length"sv)
- {
- std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second);
- if (ContentLength.has_value())
- {
- if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize)
- {
- PayloadFile = std::make_unique<detail::TempPayloadFile>();
- std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value());
- if (Ec)
- {
- ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
- TempFolderPath.string(),
- Ec.message());
- PayloadFile.reset();
- }
- }
- else
- {
- PayloadString.reserve(ContentLength.value());
- }
- }
- }
- if (!Header.first.empty())
- {
- ReceivedHeaders.emplace_back(std::move(Header));
- }
- return 1;
- };
-
- Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
- Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
- for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
- {
- Response.header.insert_or_assign(H.first, H.second);
- }
- }
- if (m_ConnectionSettings.AllowResume)
- {
- auto SupportsRanges = [](const cpr::Response& Response) -> bool {
- if (Response.header.find("Content-Range") != Response.header.end())
- {
- return true;
- }
- if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end())
- {
- return It->second == "bytes"sv;
- }
- return false;
- };
-
- auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool {
- if (ShouldRetry(Response))
- {
- return SupportsRanges(Response);
- }
- return false;
- };
-
- if (ShouldResume(Response))
- {
- auto It = Response.header.find("Content-Length");
- if (It != Response.header.end())
- {
- uint64_t ContentLength = RequestedContentLength;
- if (ContentLength == uint64_t(-1))
- {
- if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value())
- {
- ContentLength = ParsedContentLength.value();
- }
- }
-
- std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
-
- auto HeaderCallback = [&](std::string header, intptr_t) {
- std::pair<std::string, std::string> Header = GetHeader(header);
- if (!Header.first.empty())
- {
- ReceivedHeaders.emplace_back(std::move(Header));
- }
-
- if (Header.first == "Content-Range"sv)
- {
- if (Header.second.starts_with("bytes "sv))
- {
- size_t RangeStartEnd = Header.second.find('-', 6);
- if (RangeStartEnd != std::string::npos)
- {
- const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6));
- if (Start)
- {
- uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
- if (Start.value() == DownloadedSize)
- {
- return 1;
- }
- else if (Start.value() > DownloadedSize)
- {
- return 0;
- }
- if (PayloadFile)
- {
- PayloadFile->ResetWritePos(Start.value());
- }
- else
- {
- PayloadString = PayloadString.substr(0, Start.value());
- }
- return 1;
- }
- }
- }
- return 0;
- }
- return 1;
- };
-
- KeyValueMap HeadersWithRange(AdditionalHeader);
- do
- {
- uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
-
- std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
- if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
- {
- if (RangeIt->second == Range)
- {
- // If we didn't make any progress, abort
- break;
- }
- }
- HeadersWithRange.Entries.insert_or_assign("Range", Range);
-
- Session Sess =
- AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken());
- Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
- for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
- {
- Response.header.insert_or_assign(H.first, H.second);
- }
- ReceivedHeaders.clear();
- } while (ShouldResume(Response));
- }
- }
- }
-
- if (!PayloadString.empty())
- {
- Response.text = std::move(PayloadString);
- }
- return Response;
- },
- PayloadFile);
-
- return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{});
+
+ HttpContentType ContentType = HttpContentType::kUnknownContentType;
+ detail::MultipartBoundaryParser BoundaryParser;
+ bool IsMultiRangeResponse = false;
+
+ cpr::Response Response = DoWithRetry(
+ m_SessionId,
+ [&]() {
+ // Reset state from any previous attempt
+ PayloadString.clear();
+ PayloadFile.reset();
+ BoundaryParser.Boundaries.clear();
+ ContentType = HttpContentType::kUnknownContentType;
+ IsMultiRangeResponse = false;
+
+ auto DownloadCallback = [&](std::string data, intptr_t) {
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return false;
+ }
+
+ if (IsMultiRangeResponse)
+ {
+ BoundaryParser.ParseInput(data);
+ }
+
+ if (PayloadFile)
+ {
+ ZEN_ASSERT(PayloadString.empty());
+ std::error_code Ec = PayloadFile->Write(data);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ return false;
+ }
+ }
+ else
+ {
+ PayloadString.append(data);
+ }
+ return true;
+ };
+
+ uint64_t RequestedContentLength = (uint64_t)-1;
+ if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
+ {
+ if (RangeIt->second.starts_with("bytes"))
+ {
+ std::string_view RangeValue(RangeIt->second);
+ size_t RangeStartPos = RangeValue.find('=', 5);
+ if (RangeStartPos != std::string::npos)
+ {
+ RangeStartPos++;
+ while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ')
+ {
+ RangeStartPos++;
+ }
+ RequestedContentLength = 0;
+
+ while (RangeStartPos < RangeValue.length())
+ {
+ size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos);
+ if (RangeEnd == std::string::npos)
+ {
+ RangeEnd = RangeValue.length();
+ }
+
+ std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos);
+ size_t RangeSplitPos = RangeString.find('-');
+ if (RangeSplitPos != std::string::npos)
+ {
+ std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1;
+ }
+ }
+ RangeStartPos = RangeEnd;
+ while (RangeStartPos != RangeValue.length() &&
+ (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' '))
+ {
+ RangeStartPos++;
+ }
+ }
+ }
+ }
+ }
+
+ cpr::Response Response;
+ {
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header);
+ if (Header.first == "Content-Length"sv)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second);
+ if (ContentLength.has_value())
+ {
+ if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize)
+ {
+ PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ PayloadFile.reset();
+ }
+ }
+ else
+ {
+ PayloadString.reserve(ContentLength.value());
+ }
+ }
+ }
+ else if (Header.first == "Content-Type")
+ {
+ IsMultiRangeResponse = BoundaryParser.Init(Header.second);
+ if (!IsMultiRangeResponse)
+ {
+ ContentType = ParseContentType(Header.second);
+ }
+ }
+ else if (Header.first == "Content-Range")
+ {
+ if (!IsMultiRangeResponse)
+ {
+ std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Header.second);
+ if (Range.second != 0)
+ {
+ BoundaryParser.Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0,
+ .RangeOffset = Range.first,
+ .RangeLength = Range.second,
+ .ContentType = ContentType});
+ }
+ }
+ }
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+ return 1;
+ };
+
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ }
+ if (m_ConnectionSettings.AllowResume)
+ {
+ auto SupportsRanges = [](const cpr::Response& Response) -> bool {
+ if (Response.header.find("Content-Range") != Response.header.end())
+ {
+ return true;
+ }
+ if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end())
+ {
+ return It->second == "bytes"sv;
+ }
+ return false;
+ };
+
+ auto ShouldResume = [&SupportsRanges, &IsMultiRangeResponse](const cpr::Response& Response) -> bool {
+ if (IsMultiRangeResponse)
+ {
+ return false;
+ }
+ if (ShouldRetry(Response))
+ {
+ return SupportsRanges(Response);
+ }
+ return false;
+ };
+
+ if (ShouldResume(Response))
+ {
+ auto It = Response.header.find("Content-Length");
+ if (It != Response.header.end())
+ {
+ uint64_t ContentLength = RequestedContentLength;
+ if (ContentLength == uint64_t(-1))
+ {
+ if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value())
+ {
+ ContentLength = ParsedContentLength.value();
+ }
+ }
+
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header);
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+
+ if (Header.first == "Content-Range"sv)
+ {
+ if (Header.second.starts_with("bytes "sv))
+ {
+ size_t RangeStartEnd = Header.second.find('-', 6);
+ if (RangeStartEnd != std::string::npos)
+ {
+ const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6));
+ if (Start)
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+ if (Start.value() == DownloadedSize)
+ {
+ return 1;
+ }
+ else if (Start.value() > DownloadedSize)
+ {
+ return 0;
+ }
+ if (PayloadFile)
+ {
+ PayloadFile->ResetWritePos(Start.value());
+ }
+ else
+ {
+ PayloadString = PayloadString.substr(0, Start.value());
+ }
+ return 1;
+ }
+ }
+ }
+ return 0;
+ }
+ return 1;
+ };
+
+ KeyValueMap HeadersWithRange(AdditionalHeader);
+ do
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+
+ std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
+ if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
+ {
+ if (RangeIt->second == Range)
+ {
+ // If we didn't make any progress, abort
+ break;
+ }
+ }
+ HeadersWithRange.Entries.insert_or_assign("Range", Range);
+
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken());
+ Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ ReceivedHeaders.clear();
+ } while (ShouldResume(Response));
+ }
+ }
+ }
+
+ if (!PayloadString.empty())
+ {
+ Response.text = std::move(PayloadString);
+ }
+ return Response;
+ },
+ PayloadFile);
+
+ return CommonResponse(m_SessionId,
+ std::move(Response),
+ PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{},
+ std::move(BoundaryParser.Boundaries));
}
} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h
index 40af53b5d..752d91add 100644
--- a/src/zenhttp/clients/httpclientcpr.h
+++ b/src/zenhttp/clients/httpclientcpr.h
@@ -155,14 +155,19 @@ private:
std::function<cpr::Response()>&& Func,
std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; });
+ bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const;
bool ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile);
- HttpClient::Response CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload);
+ HttpClient::Response CommonResponse(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {});
- HttpClient::Response ResponseWithPayload(std::string_view SessionId,
- cpr::Response&& HttpResponse,
- const HttpResponseCode WorkResponseCode,
- IoBuffer&& Payload);
+ HttpClient::Response ResponseWithPayload(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions);
};
} // namespace zen
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
new file mode 100644
index 000000000..9497dadb8
--- /dev/null
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -0,0 +1,566 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpwsclient.h>
+
+#include "../servers/wsframecodec.h"
+
+#include <zencore/base64.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <deque>
+#include <random>
+#include <thread>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpWsClient::Impl
+{
+ Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings)
+ : m_Handler(Handler)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_OwnedIoContext(std::make_unique<asio::io_context>())
+ , m_IoContext(*m_OwnedIoContext)
+ {
+ ParseUrl(Url);
+ }
+
+ Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings)
+ : m_Handler(Handler)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_IoContext(IoContext)
+ {
+ ParseUrl(Url);
+ }
+
+ ~Impl()
+ {
+ // Release work guard so io_context::run() can return
+ m_WorkGuard.reset();
+
+ // Close the socket to cancel pending async ops
+ if (m_Socket)
+ {
+ asio::error_code Ec;
+ m_Socket->close(Ec);
+ }
+
+ if (m_IoThread.joinable())
+ {
+ m_IoThread.join();
+ }
+ }
+
+ void ParseUrl(std::string_view Url)
+ {
+ // Expected format: ws://host:port/path
+ if (Url.substr(0, 5) == "ws://")
+ {
+ Url.remove_prefix(5);
+ }
+
+ auto SlashPos = Url.find('/');
+ std::string_view HostPort;
+ if (SlashPos != std::string_view::npos)
+ {
+ HostPort = Url.substr(0, SlashPos);
+ m_Path = std::string(Url.substr(SlashPos));
+ }
+ else
+ {
+ HostPort = Url;
+ m_Path = "/";
+ }
+
+ auto ColonPos = HostPort.find(':');
+ if (ColonPos != std::string_view::npos)
+ {
+ m_Host = std::string(HostPort.substr(0, ColonPos));
+ m_Port = std::string(HostPort.substr(ColonPos + 1));
+ }
+ else
+ {
+ m_Host = std::string(HostPort);
+ m_Port = "80";
+ }
+ }
+
+ void Connect()
+ {
+ if (m_OwnedIoContext)
+ {
+ m_WorkGuard = std::make_unique<asio::io_context::work>(m_IoContext);
+ m_IoThread = std::thread([this] { m_IoContext.run(); });
+ }
+
+ asio::post(m_IoContext, [this] { DoResolve(); });
+ }
+
+ void DoResolve()
+ {
+ m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext);
+
+ m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) {
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message());
+ m_Handler.OnWsClose(1006, "resolve failed");
+ return;
+ }
+
+ DoConnect(Results);
+ });
+ }
+
+ void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints)
+ {
+ m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoContext);
+
+ // Start connect timeout timer
+ m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout);
+ m_Timer->async_wait([this](const asio::error_code& Ec) {
+ if (!Ec && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port);
+ if (m_Socket)
+ {
+ asio::error_code CloseEc;
+ m_Socket->close(CloseEc);
+ }
+ }
+ });
+
+ asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message());
+ m_Handler.OnWsClose(1006, "connect failed");
+ return;
+ }
+
+ DoHandshake();
+ });
+ }
+
+ void DoHandshake()
+ {
+ // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded)
+ uint8_t KeyBytes[16];
+ {
+ static thread_local std::mt19937 s_Rng(std::random_device{}());
+ for (int i = 0; i < 4; ++i)
+ {
+ uint32_t Val = s_Rng();
+ std::memcpy(KeyBytes + i * 4, &Val, 4);
+ }
+ }
+
+ char KeyBase64[Base64::GetEncodedDataSize(16) + 1];
+ uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64);
+ KeyBase64[KeyLen] = '\0';
+ m_WebSocketKey = std::string(KeyBase64, KeyLen);
+
+ // Build the HTTP upgrade request
+ ExtendableStringBuilder<512> Request;
+ Request << "GET " << m_Path << " HTTP/1.1\r\n"
+ << "Host: " << m_Host << ":" << m_Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n"
+ << "Sec-WebSocket-Version: 13\r\n";
+
+ // Add Authorization header if access token provider is set
+ if (m_Settings.AccessTokenProvider)
+ {
+ HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)();
+ if (Token.IsValid())
+ {
+ Request << "Authorization: Bearer " << Token.Value << "\r\n";
+ }
+ }
+
+ Request << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+
+ m_HandshakeBuffer = std::make_shared<std::string>(ReqStr);
+
+ asio::async_write(*m_Socket,
+ asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()),
+ [this](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake write failed");
+ return;
+ }
+
+ DoReadHandshakeResponse();
+ });
+ }
+
+ void DoReadHandshakeResponse()
+ {
+ asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) {
+ m_Timer->cancel();
+
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake read failed");
+ return;
+ }
+
+ // Parse the response
+ const auto& Data = m_ReadBuffer.data();
+ std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data));
+
+ // Consume the headers from the read buffer (any extra data stays for frame parsing)
+ auto HeaderEnd = Response.find("\r\n\r\n");
+ if (HeaderEnd != std::string::npos)
+ {
+ m_ReadBuffer.consume(HeaderEnd + 4);
+ }
+
+ // Validate 101 response
+ if (Response.find("101") == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80));
+ m_Handler.OnWsClose(1006, "handshake rejected");
+ return;
+ }
+
+ // Validate Sec-WebSocket-Accept
+ std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey);
+ if (Response.find(ExpectedAccept) == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept");
+ m_Handler.OnWsClose(1006, "invalid accept key");
+ return;
+ }
+
+ m_IsOpen.store(true);
+ m_Handler.OnWsOpen();
+ EnqueueRead();
+ });
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Read loop
+ //
+
+ void EnqueueRead()
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) {
+ OnDataReceived(Ec);
+ });
+ }
+
+ void OnDataReceived(const asio::error_code& Ec)
+ {
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message());
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWsClose(1006, "connection lost");
+ }
+ return;
+ }
+
+ ProcessReceivedData();
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ EnqueueRead();
+ }
+ }
+
+ void ProcessReceivedData()
+ {
+ while (m_ReadBuffer.size() > 0)
+ {
+ const auto& InputBuffer = m_ReadBuffer.data();
+ const auto* RawData = static_cast<const uint8_t*>(InputBuffer.data());
+ const auto Size = InputBuffer.size();
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size);
+ if (!Frame.IsValid)
+ {
+ break;
+ }
+
+ m_ReadBuffer.consume(Frame.BytesConsumed);
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWsMessage(Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with masked pong
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason =
+ std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo masked close frame if we haven't sent one yet
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWsClose(Code, Reason);
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Write queue
+ //
+
+ void EnqueueWrite(std::vector<uint8_t> Frame)
+ {
+ bool ShouldFlush = false;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_WriteQueue.push_back(std::move(Frame));
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ });
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+ }
+
+ void FlushWriteQueue()
+ {
+ std::vector<uint8_t> Frame;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+ Frame = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ });
+
+ if (Frame.empty())
+ {
+ return;
+ }
+
+ auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));
+
+ asio::async_write(*m_Socket,
+ asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
+ [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); });
+ }
+
+ void OnWriteComplete(const asio::error_code& Ec)
+ {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message());
+ }
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_IsWriting = false;
+ m_WriteQueue.clear();
+ });
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWsClose(1006, "write error");
+ }
+ return;
+ }
+
+ FlushWriteQueue();
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Public operations
+ //
+
+ void SendText(std::string_view Text)
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+ }
+
+ void SendBinary(std::span<const uint8_t> Data)
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+ }
+
+ void DoClose(uint16_t Code, std::string_view Reason)
+ {
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ IWsClientHandler& m_Handler;
+ HttpWsClientSettings m_Settings;
+ LoggerRef m_Log;
+
+ std::string m_Host;
+ std::string m_Port;
+ std::string m_Path;
+
+ // io_context: owned (standalone) or external (shared)
+ std::unique_ptr<asio::io_context> m_OwnedIoContext;
+ asio::io_context& m_IoContext;
+ std::unique_ptr<asio::io_context::work> m_WorkGuard;
+ std::thread m_IoThread;
+
+ // Connection state
+ std::unique_ptr<asio::ip::tcp::resolver> m_Resolver;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<asio::steady_timer> m_Timer;
+ asio::streambuf m_ReadBuffer;
+ std::string m_WebSocketKey;
+ std::shared_ptr<std::string> m_HandshakeBuffer;
+
+ // Write queue
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ bool m_IsWriting = false;
+
+ std::atomic<bool> m_IsOpen{false};
+ std::atomic<bool> m_CloseSent{false};
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(Url, Handler, Settings))
+{
+}
+
+HttpWsClient::HttpWsClient(std::string_view Url,
+ IWsClientHandler& Handler,
+ asio::io_context& IoContext,
+ const HttpWsClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(Url, Handler, IoContext, Settings))
+{
+}
+
+HttpWsClient::~HttpWsClient() = default;
+
+void
+HttpWsClient::Connect()
+{
+ m_Impl->Connect();
+}
+
+void
+HttpWsClient::SendText(std::string_view Text)
+{
+ m_Impl->SendText(Text);
+}
+
+void
+HttpWsClient::SendBinary(std::span<const uint8_t> Data)
+{
+ m_Impl->SendBinary(Data);
+}
+
+void
+HttpWsClient::Close(uint16_t Code, std::string_view Reason)
+{
+ m_Impl->DoClose(Code, Reason);
+}
+
+bool
+HttpWsClient::IsOpen() const
+{
+ return m_Impl->m_IsOpen.load(std::memory_order_relaxed);
+}
+
+} // namespace zen
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
index 43e9fb468..281d512cf 100644
--- a/src/zenhttp/httpclient.cpp
+++ b/src/zenhttp/httpclient.cpp
@@ -21,9 +21,17 @@
#include "clients/httpclientcommon.h"
+#include <numeric>
+
#if ZEN_WITH_TESTS
+# include <zencore/scopeguard.h>
# include <zencore/testing.h>
# include <zencore/testutils.h>
+# include <zenhttp/security/passwordsecurityfilter.h>
+# include "servers/httpasio.h"
+# include "servers/httpsys.h"
+
+# include <thread>
#endif // ZEN_WITH_TESTS
namespace zen {
@@ -96,6 +104,44 @@ HttpClientBase::GetAccessToken()
//////////////////////////////////////////////////////////////////////////
+std::vector<std::pair<uint64_t, uint64_t>>
+HttpClient::Response::GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const
+{
+ if (Ranges.empty())
+ {
+ return {};
+ }
+
+ std::vector<std::pair<uint64_t, uint64_t>> Result;
+ Result.reserve(OffsetAndLengthPairs.size());
+
+ auto BoundaryIt = Ranges.begin();
+ auto OffsetAndLengthPairIt = OffsetAndLengthPairs.begin();
+ while (OffsetAndLengthPairIt != OffsetAndLengthPairs.end())
+ {
+ uint64_t Offset = OffsetAndLengthPairIt->first;
+ uint64_t Length = OffsetAndLengthPairIt->second;
+ while (Offset >= BoundaryIt->RangeOffset + BoundaryIt->RangeLength)
+ {
+ BoundaryIt++;
+ if (BoundaryIt == Ranges.end())
+ {
+ throw std::runtime_error("HttpClient::Response can not fulfill requested range");
+ }
+ }
+ if (Offset + Length > BoundaryIt->RangeOffset + BoundaryIt->RangeLength || Offset < BoundaryIt->RangeOffset)
+ {
+ throw std::runtime_error("HttpClient::Response can not fulfill requested range");
+ }
+ uint64_t OffsetIntoRange = Offset - BoundaryIt->RangeOffset;
+ uint64_t RangePayloadOffset = BoundaryIt->OffsetInPayload + OffsetIntoRange;
+ Result.emplace_back(std::make_pair(RangePayloadOffset, Length));
+
+ OffsetAndLengthPairIt++;
+ }
+ return Result;
+}
+
CbObject
HttpClient::Response::AsObject() const
{
@@ -334,10 +380,55 @@ HttpClient::Authenticate()
return m_Inner->Authenticate();
}
+LatencyTestResult
+MeasureLatency(HttpClient& Client, std::string_view Url)
+{
+ std::vector<double> MeasurementTimes;
+ std::string ErrorMessage;
+
+ for (uint32_t AttemptCount = 0; AttemptCount < 20 && MeasurementTimes.size() < 5; AttemptCount++)
+ {
+ HttpClient::Response MeasureResponse = Client.Get(Url);
+ if (MeasureResponse.IsSuccess())
+ {
+ MeasurementTimes.push_back(MeasureResponse.ElapsedSeconds);
+ Sleep(5);
+ }
+ else
+ {
+ ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url));
+
+ // Connection-level failures (timeout, refused, DNS) mean the endpoint is unreachable.
+ // Bail out immediately — retrying will just burn the connect timeout each time.
+ if (MeasureResponse.Error && MeasureResponse.Error->IsConnectionError())
+ {
+ break;
+ }
+ }
+ }
+
+ if (MeasurementTimes.empty())
+ {
+ return {.Success = false, .FailureReason = ErrorMessage};
+ }
+
+ if (MeasurementTimes.size() > 2)
+ {
+ std::sort(MeasurementTimes.begin(), MeasurementTimes.end());
+ MeasurementTimes.pop_back(); // Remove the worst time
+ }
+
+ double AverageLatency = std::accumulate(MeasurementTimes.begin(), MeasurementTimes.end(), 0.0) / MeasurementTimes.size();
+
+ return {.Success = true, .LatencySeconds = AverageLatency};
+}
+
//////////////////////////////////////////////////////////////////////////
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("http.httpclient");
+
TEST_CASE("responseformat")
{
using namespace std::literals;
@@ -388,8 +479,366 @@ TEST_CASE("httpclient")
{
using namespace std::literals;
- SUBCASE("client") {}
+ struct TestHttpService : public HttpService
+ {
+ TestHttpService() = default;
+
+ virtual const char* BaseUri() const override { return "/test/"; }
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override
+ {
+ if (HttpServiceRequest.RelativeUri() == "yo")
+ {
+ if (HttpServiceRequest.IsLocalMachineRequest())
+ {
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family");
+ }
+ else
+ {
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey stranger");
+ }
+ }
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::OK);
+ }
+ };
+
+ TestHttpService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ SUBCASE("asio")
+ {
+ Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = AsioServer->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ AsioServer->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { AsioServer->Run(false); });
+
+ {
+ auto _ = MakeGuard([&]() {
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ AsioServer->Close();
+ });
+
+ {
+ HttpClient Client(fmt::format("127.0.0.1:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+
+ if (IsIPv6Capable())
+ {
+ HttpClient Client(fmt::format("[::1]:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+
+ {
+ HttpClient Client(fmt::format("localhost:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+# if 0
+ {
+ HttpClient Client(fmt::format("10.24.101.77:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+ Sleep(20000);
+# endif // 0
+ AsioServer->RequestExit();
+ }
+ }
+
+# if ZEN_PLATFORM_WINDOWS
+ SUBCASE("httpsys")
+ {
+ Ref<HttpServer> HttpSysServer = CreateHttpSysServer(HttpSysConfig{.ForceLoopback = false});
+
+ int Port = HttpSysServer->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ HttpSysServer->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { HttpSysServer->Run(false); });
+
+ {
+ auto _ = MakeGuard([&]() {
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ HttpSysServer->Close();
+ });
+
+ if (true)
+ {
+ HttpClient Client(fmt::format("127.0.0.1:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+
+ if (IsIPv6Capable())
+ {
+ HttpClient Client(fmt::format("[::1]:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+
+ {
+ HttpClient Client(fmt::format("localhost:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+# if 0
+ {
+ HttpClient Client(fmt::format("10.24.101.77:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+ Sleep(20000);
+# endif // 0
+ HttpSysServer->RequestExit();
+ }
+ }
+# endif // ZEN_PLATFORM_WINDOWS
+}
+
+TEST_CASE("httpclient.requestfilter")
+{
+ using namespace std::literals;
+
+ struct TestHttpService : public HttpService
+ {
+ TestHttpService() = default;
+
+ virtual const char* BaseUri() const override { return "/test/"; }
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override
+ {
+ if (HttpServiceRequest.RelativeUri() == "yo")
+ {
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family");
+ }
+
+ {
+ CHECK(HttpServiceRequest.RelativeUri() != "should_filter");
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+
+ {
+ CHECK(HttpServiceRequest.RelativeUri() != "should_forbid");
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+ }
+ };
+
+ TestHttpService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ class MyFilterImpl : public IHttpRequestFilter
+ {
+ public:
+ virtual Result FilterRequest(HttpServerRequest& Request)
+ {
+ if (Request.RelativeUri() == "should_filter")
+ {
+ Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "no thank you");
+ return Result::ResponseSent;
+ }
+ else if (Request.RelativeUri() == "should_forbid")
+ {
+ return Result::Forbidden;
+ }
+ return Result::Accepted;
+ }
+ };
+
+ MyFilterImpl MyFilter;
+
+ Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{});
+
+ AsioServer->SetHttpRequestFilter(&MyFilter);
+
+ int Port = AsioServer->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ AsioServer->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { AsioServer->Run(false); });
+
+ {
+ auto _ = MakeGuard([&]() {
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ AsioServer->Close();
+ });
+
+ HttpClient Client(fmt::format("localhost:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response YoResponse = Client.Get("/test/yo");
+ CHECK(YoResponse.IsSuccess());
+ CHECK_EQ(YoResponse.AsText(), "hey family");
+
+ HttpClient::Response ShouldFilterResponse = Client.Get("/test/should_filter");
+ CHECK_EQ(ShouldFilterResponse.StatusCode, HttpResponseCode::MethodNotAllowed);
+ CHECK_EQ(ShouldFilterResponse.AsText(), "no thank you");
+
+ HttpClient::Response ShouldForbitResponse = Client.Get("/test/should_forbid");
+ CHECK_EQ(ShouldForbitResponse.StatusCode, HttpResponseCode::Forbidden);
+
+ AsioServer->RequestExit();
+ }
+}
+
+TEST_CASE("httpclient.password")
+{
+ using namespace std::literals;
+
+ struct TestHttpService : public HttpService
+ {
+ TestHttpService() = default;
+
+ virtual const char* BaseUri() const override { return "/test/"; }
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override
+ {
+ if (HttpServiceRequest.RelativeUri() == "yo")
+ {
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family");
+ }
+
+ {
+ CHECK(HttpServiceRequest.RelativeUri() != "should_filter");
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+
+ {
+ CHECK(HttpServiceRequest.RelativeUri() != "should_forbid");
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+ }
+ };
+
+ TestHttpService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = AsioServer->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ AsioServer->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { AsioServer->Run(false); });
+
+ {
+ auto _ = MakeGuard([&]() {
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ AsioServer->Close();
+ });
+
+ SUBCASE("usernamepassword")
+ {
+ CbObjectWriter Writer;
+ {
+ Writer.BeginObject("basic");
+ {
+ Writer << "username"sv
+ << "me";
+ Writer << "password"sv
+ << "456123789";
+ }
+ Writer.EndObject();
+ Writer << "protect-machine-local-requests" << true;
+ }
+
+ PasswordHttpFilter::Configuration PasswordFilterOptions = PasswordHttpFilter::ReadConfiguration(Writer.Save());
+
+ PasswordHttpFilter MyFilter(PasswordFilterOptions);
+
+ AsioServer->SetHttpRequestFilter(&MyFilter);
+
+ HttpClient Client(fmt::format("localhost:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response ForbiddenResponse = Client.Get("/test/yo");
+ CHECK(!ForbiddenResponse.IsSuccess());
+ CHECK_EQ(ForbiddenResponse.StatusCode, HttpResponseCode::Forbidden);
+
+ HttpClient::Response WithBasicResponse =
+ Client.Get("/test/yo",
+ std::pair<std::string, std::string>("Authorization",
+ fmt::format("Basic {}", PasswordFilterOptions.PasswordConfig.Password)));
+ CHECK(WithBasicResponse.IsSuccess());
+ AsioServer->SetHttpRequestFilter(nullptr);
+ }
+ AsioServer->RequestExit();
+ }
}
+TEST_SUITE_END();
void
httpclient_forcelink()
diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp
new file mode 100644
index 000000000..52bf149a7
--- /dev/null
+++ b/src/zenhttp/httpclient_test.cpp
@@ -0,0 +1,1366 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpclient.h>
+#include <zenhttp/httpserver.h>
+
+#if ZEN_WITH_TESTS
+
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinaryutil.h>
+# include <zencore/compositebuffer.h>
+# include <zencore/iobuffer.h>
+# include <zencore/logging.h>
+# include <zencore/scopeguard.h>
+# include <zencore/session.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+
+# include "servers/httpasio.h"
+
+# include <atomic>
+# include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+// Test service
+
+// In-process HTTP service exercised by the client tests below. Registers a
+// fixed set of routes under BaseUri() ("/api/test/") covering verbs, echo,
+// content types, auth, delays, large payloads, arbitrary status codes and a
+// retry counter.
+class HttpClientTestService : public HttpService
+{
+public:
+ HttpClientTestService()
+ {
+ // Route-capture matcher: accepts only non-empty, all-digit tokens.
+ // This guarantees the "status/{statuscode}" handler below only ever
+ // sees numeric captures.
+ m_Router.AddMatcher("statuscode", [](std::string_view Str) -> bool {
+ for (char C : Str)
+ {
+ if (C < '0' || C > '9')
+ {
+ return false;
+ }
+ }
+ return !Str.empty();
+ });
+
+ m_Router.RegisterRoute(
+ "hello",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello world"); },
+ HttpVerb::kGet);
+
+ // Echoes the request body back with the same content type.
+ m_Router.RegisterRoute(
+ "echo",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ IoBuffer Body = HttpReq.ReadPayload();
+ HttpContentType CT = HttpReq.RequestContentType();
+ HttpReq.WriteResponse(HttpResponseCode::OK, CT, Body);
+ },
+ HttpVerb::kPost | HttpVerb::kPut);
+
+ // Reflects the Authorization header (if any) back as a CbObject field.
+ m_Router.RegisterRoute(
+ "echo/headers",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view Auth = HttpReq.GetAuthorizationHeader();
+ CbObjectWriter Writer;
+ if (!Auth.empty())
+ {
+ Writer.AddString("Authorization", Auth);
+ }
+ HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save());
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ // Returns the request verb as plain text ("GET", "POST", ...).
+ m_Router.RegisterRoute(
+ "echo/method",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view Method = ToString(HttpReq.RequestVerb());
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Method);
+ },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead);
+
+ m_Router.RegisterRoute(
+ "json",
+ [](HttpRouterRequest& Req) {
+ CbObjectWriter Obj;
+ Obj.AddBool("ok", true);
+ Obj.AddString("message", "test");
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "nocontent",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete);
+
+ m_Router.RegisterRoute(
+ "created",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::Created, HttpContentType::kText, "resource created");
+ },
+ HttpVerb::kPost | HttpVerb::kPut);
+
+ m_Router.RegisterRoute(
+ "content-type/text",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "plain text"); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "content-type/json",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"key\":\"value\"}");
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "content-type/binary",
+ [](HttpRouterRequest& Req) {
+ uint8_t Data[] = {0xDE, 0xAD, 0xBE, 0xEF};
+ IoBuffer Buf(IoBuffer::Clone, Data, sizeof(Data));
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "content-type/cbobject",
+ [](HttpRouterRequest& Req) {
+ CbObjectWriter Obj;
+ Obj.AddString("type", "cbobject");
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ },
+ HttpVerb::kGet);
+
+ // 200 only when an Authorization header of the form "Bearer <token>"
+ // (non-empty token) is present; 401 otherwise.
+ m_Router.RegisterRoute(
+ "auth/bearer",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view Auth = HttpReq.GetAuthorizationHeader();
+ if (Auth.starts_with("Bearer ") && Auth.size() > 7)
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "authenticated");
+ }
+ else
+ {
+ HttpReq.WriteResponse(HttpResponseCode::Unauthorized, HttpContentType::kText, "unauthorized");
+ }
+ },
+ HttpVerb::kGet);
+
+ // Responds after a 2 second delay (asynchronously, so the server worker
+ // is not blocked). Used by the client-timeout tests.
+ m_Router.RegisterRoute(
+ "slow",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) {
+ Sleep(2000);
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow response");
+ });
+ },
+ HttpVerb::kGet);
+
+ // 64 KiB deterministic payload (byte i == i & 0xFF) so the client side
+ // can validate content integrity.
+ m_Router.RegisterRoute(
+ "large",
+ [](HttpRouterRequest& Req) {
+ constexpr size_t Size = 64 * 1024;
+ IoBuffer Buf(Size);
+ uint8_t* Ptr = static_cast<uint8_t*>(Buf.MutableData());
+ for (size_t i = 0; i < Size; ++i)
+ {
+ Ptr[i] = static_cast<uint8_t>(i & 0xFF);
+ }
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf);
+ },
+ HttpVerb::kGet);
+
+ // Replies with the captured status code. The "statuscode" matcher above
+ // restricts the capture to digits, so std::stoi cannot see a non-numeric
+ // string here (it could still throw out_of_range for absurdly long runs
+ // of digits, which the tests never send).
+ m_Router.RegisterRoute(
+ "status/{statuscode}",
+ [](HttpRouterRequest& Req) {
+ std::string_view CodeStr = Req.GetCapture(1);
+ int Code = std::stoi(std::string{CodeStr});
+ const HttpResponseCode ResponseCode = static_cast<HttpResponseCode>(Code);
+ Req.ServerRequest().WriteResponse(ResponseCode);
+ },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead);
+
+ // Fails with 503 for the first m_FailCount requests, then succeeds.
+ // Drives the retry tests; reset between tests via ResetAttemptCounter().
+ m_Router.RegisterRoute(
+ "attempt-counter",
+ [this](HttpRouterRequest& Req) {
+ uint32_t Count = m_AttemptCounter.fetch_add(1);
+ if (Count < m_FailCount)
+ {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::ServiceUnavailable);
+ }
+ else
+ {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "success after retries");
+ }
+ },
+ HttpVerb::kGet);
+ }
+
+ virtual const char* BaseUri() const override { return "/api/test/"; }
+ virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); }
+
+ // Re-arm the attempt-counter route. NOTE(review): m_FailCount is a plain
+ // uint32_t written here and read on server worker threads; callers invoke
+ // this before issuing requests, so no race in practice — confirm if reused.
+ void ResetAttemptCounter(uint32_t FailCount)
+ {
+ m_AttemptCounter.store(0);
+ m_FailCount = FailCount;
+ }
+
+private:
+ HttpRequestRouter m_Router;
+ std::atomic<uint32_t> m_AttemptCounter{0};
+ uint32_t m_FailCount = 2;
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Test server fixture
+
+// RAII fixture: starts an asio HTTP server hosting HttpClientTestService on a
+// background thread in its constructor, and shuts it down (RequestExit ->
+// join -> Close, in that order) in its destructor.
+struct TestServerFixture
+{
+ HttpClientTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+ Ref<HttpServer> Server;
+ std::thread ServerThread;
+ int Port = -1;
+
+ TestServerFixture()
+ {
+ Server = CreateHttpAsioServer(AsioConfig{});
+ // 7600 is the requested port; the bound port is taken from the return
+ // value (presumably Initialize falls back to another port if 7600 is
+ // busy — TODO confirm), so clients must use Port, not 7600.
+ Port = Server->Initialize(7600, TmpDir.Path());
+ ZEN_ASSERT(Port != -1);
+ Server->RegisterService(TestService);
+ ServerThread = std::thread([this]() { Server->Run(false); });
+ }
+
+ ~TestServerFixture()
+ {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ }
+
+ // Builds a client targeting this fixture's server; Settings lets individual
+ // tests override timeouts, retry counts, token providers, etc.
+ HttpClient MakeClient(HttpClientSettings Settings = {})
+ {
+ return HttpClient(fmt::format("127.0.0.1:{}", Port), Settings, /*CheckIfAbortFunction*/ {});
+ }
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Tests
+
+TEST_SUITE_BEGIN("http.httpclient");
+
+// Each HTTP verb reaches the matching handler; the echo/method route returns
+// the verb name so dispatch can be asserted directly.
+TEST_CASE("httpclient.verbs")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("GET returns 200 with expected body")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "GET");
+ }
+
+ SUBCASE("POST dispatches correctly")
+ {
+ HttpClient::Response Resp = Client.Post("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "POST");
+ }
+
+ SUBCASE("PUT dispatches correctly")
+ {
+ HttpClient::Response Resp = Client.Put("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "PUT");
+ }
+
+ SUBCASE("DELETE dispatches correctly")
+ {
+ HttpClient::Response Resp = Client.Delete("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "DELETE");
+ }
+
+ // HEAD responses carry headers only, so the body must be empty even though
+ // the handler writes one.
+ SUBCASE("HEAD returns 200 with empty body")
+ {
+ HttpClient::Response Resp = Client.Head("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), ""sv);
+ }
+}
+
+// GET behavior: text bodies, per-request headers, CbObject responses, and
+// integrity of a 64 KiB deterministic payload.
+TEST_CASE("httpclient.get")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("simple GET with text response")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK);
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("GET with auth header via echo")
+ {
+ HttpClient::Response Resp =
+ Client.Get("/api/test/echo/headers", std::pair<std::string, std::string>("Authorization", "Bearer test-token-123"));
+ CHECK(Resp.IsSuccess());
+ CbObject Obj = Resp.AsObject();
+ CHECK_EQ(Obj["Authorization"].AsString(), "Bearer test-token-123");
+ }
+
+ SUBCASE("GET returning CbObject")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/json");
+ CHECK(Resp.IsSuccess());
+ CbObject Obj = Resp.AsObject();
+ CHECK(Obj["ok"].AsBool() == true);
+ CHECK_EQ(Obj["message"].AsString(), "test");
+ }
+
+ SUBCASE("GET large payload")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/large");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u);
+
+ // The "large" route writes byte i == (i & 0xFF); verify every byte.
+ const uint8_t* Data = static_cast<const uint8_t*>(Resp.ResponsePayload.GetData());
+ bool Valid = true;
+ for (size_t i = 0; i < 64 * 1024; ++i)
+ {
+ if (Data[i] != static_cast<uint8_t>(i & 0xFF))
+ {
+ Valid = false;
+ break;
+ }
+ }
+ CHECK(Valid);
+ }
+}
+
+// POST round-trips through the echo route with the various payload types the
+// client supports: IoBuffer, CbObject, CompositeBuffer, key/value maps.
+TEST_CASE("httpclient.post")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("POST with IoBuffer payload echo round-trip")
+ {
+ const char* Payload = "test payload data";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+ Buf.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Buf);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "test payload data");
+ }
+
+ SUBCASE("POST with IoBuffer and explicit content type")
+ {
+ const char* Payload = "{\"key\":\"value\"}";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+
+ // Content type passed as an argument instead of being set on the buffer.
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Buf, ZenContentType::kJSON);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "{\"key\":\"value\"}");
+ }
+
+ SUBCASE("POST with CbObject payload round-trip")
+ {
+ CbObjectWriter Writer;
+ Writer.AddBool("enabled", true);
+ Writer.AddString("name", "testobj");
+ CbObject Obj = Writer.Save();
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Obj);
+ CHECK(Resp.IsSuccess());
+ CbObject RoundTripped = Resp.AsObject();
+ CHECK(RoundTripped["enabled"].AsBool() == true);
+ CHECK_EQ(RoundTripped["name"].AsString(), "testobj");
+ }
+
+ // Multi-segment body must arrive concatenated in segment order.
+ SUBCASE("POST with CompositeBuffer payload")
+ {
+ const char* Part1 = "hello ";
+ const char* Part2 = "composite";
+ IoBuffer Buf1(IoBuffer::Clone, Part1, strlen(Part1));
+ IoBuffer Buf2(IoBuffer::Clone, Part2, strlen(Part2));
+
+ SharedBuffer Seg1{Buf1};
+ SharedBuffer Seg2{Buf2};
+ CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)};
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Composite, ZenContentType::kText);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello composite");
+ }
+
+ SUBCASE("POST with custom headers")
+ {
+ HttpClient::Response Resp = Client.Post("/api/test/echo/headers", HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{});
+ CHECK(Resp.IsSuccess());
+ }
+
+ SUBCASE("POST with empty body to nocontent endpoint")
+ {
+ HttpClient::Response Resp = Client.Post("/api/test/nocontent");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent);
+ }
+}
+
+// PUT variants: echo round-trip, empty body, and a 201 Created response.
+TEST_CASE("httpclient.put")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("PUT with IoBuffer payload echo round-trip")
+ {
+ const char* Payload = "put payload data";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+ Buf.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Put("/api/test/echo", Buf);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "put payload data");
+ }
+
+ SUBCASE("PUT with parameters only")
+ {
+ HttpClient::Response Resp = Client.Put("/api/test/nocontent");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent);
+ }
+
+ SUBCASE("PUT to created endpoint")
+ {
+ const char* Payload = "new resource";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+ Buf.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Put("/api/test/created", Buf);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::Created);
+ CHECK_EQ(Resp.AsText(), "resource created");
+ }
+}
+
+// Upload() path (as opposed to plain Post) with single and multi-segment
+// payloads; the echo route returns the body so size can be verified.
+TEST_CASE("httpclient.upload")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("Upload IoBuffer")
+ {
+ constexpr size_t Size = 128 * 1024;
+ IoBuffer Blob = CreateSemiRandomBlob(Size);
+
+ HttpClient::Response Resp = Client.Upload("/api/test/echo", Blob);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), Size);
+ }
+
+ SUBCASE("Upload CompositeBuffer")
+ {
+ IoBuffer Buf1 = CreateSemiRandomBlob(32 * 1024);
+ IoBuffer Buf2 = CreateSemiRandomBlob(32 * 1024);
+
+ SharedBuffer Seg1{Buf1};
+ SharedBuffer Seg2{Buf2};
+ CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)};
+
+ // Two 32 KiB segments echo back as one 64 KiB body.
+ HttpClient::Response Resp = Client.Upload("/api/test/echo", Composite, ZenContentType::kBinary);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u);
+ }
+}
+
+// Download() path: small bodies stay in memory, while shrinking
+// MaximumInMemoryDownloadSize forces the payload to spill to DownloadDir.
+TEST_CASE("httpclient.download")
+{
+ TestServerFixture Fixture;
+ ScopedTemporaryDirectory DownloadDir;
+
+ SUBCASE("Download small payload stays in memory")
+ {
+ HttpClient Client = Fixture.MakeClient();
+
+ HttpClient::Response Resp = Client.Download("/api/test/hello", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("Download with reduced MaximumInMemoryDownloadSize forces file spill")
+ {
+ HttpClientSettings Settings;
+ // 4-byte in-memory cap: the 64 KiB "large" body must go to disk, but the
+ // response payload is still exposed with the full size.
+ Settings.MaximumInMemoryDownloadSize = 4;
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Download("/api/test/large", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u);
+ }
+}
+
+// The status/{code} route reflects arbitrary codes; verify IsSuccess()
+// classification (2xx true, 4xx/5xx false) and exact StatusCode mapping.
+TEST_CASE("httpclient.status-codes")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("2xx are success")
+ {
+ CHECK(Client.Get("/api/test/status/200").IsSuccess());
+ CHECK(Client.Get("/api/test/status/201").IsSuccess());
+ CHECK(Client.Get("/api/test/status/204").IsSuccess());
+ }
+
+ SUBCASE("4xx are not success")
+ {
+ CHECK(!Client.Get("/api/test/status/400").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/401").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/403").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/404").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/409").IsSuccess());
+ }
+
+ SUBCASE("5xx are not success")
+ {
+ CHECK(!Client.Get("/api/test/status/500").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/502").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/503").IsSuccess());
+ }
+
+ SUBCASE("status code values match")
+ {
+ CHECK_EQ(Client.Get("/api/test/status/200").StatusCode, HttpResponseCode::OK);
+ CHECK_EQ(Client.Get("/api/test/status/201").StatusCode, HttpResponseCode::Created);
+ CHECK_EQ(Client.Get("/api/test/status/204").StatusCode, HttpResponseCode::NoContent);
+ CHECK_EQ(Client.Get("/api/test/status/400").StatusCode, HttpResponseCode::BadRequest);
+ CHECK_EQ(Client.Get("/api/test/status/401").StatusCode, HttpResponseCode::Unauthorized);
+ CHECK_EQ(Client.Get("/api/test/status/403").StatusCode, HttpResponseCode::Forbidden);
+ CHECK_EQ(Client.Get("/api/test/status/404").StatusCode, HttpResponseCode::NotFound);
+ CHECK_EQ(Client.Get("/api/test/status/409").StatusCode, HttpResponseCode::Conflict);
+ CHECK_EQ(Client.Get("/api/test/status/500").StatusCode, HttpResponseCode::InternalServerError);
+ CHECK_EQ(Client.Get("/api/test/status/502").StatusCode, HttpResponseCode::BadGateway);
+ CHECK_EQ(Client.Get("/api/test/status/503").StatusCode, HttpResponseCode::ServiceUnavailable);
+ }
+}
+
+// HttpClient::Response accessor contract: success predicates, AsText/AsObject/
+// ToText conversions, ErrorMessage formatting and ThrowError behavior.
+TEST_CASE("httpclient.response")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("IsSuccess and operator bool for success")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK(static_cast<bool>(Resp));
+ }
+
+ SUBCASE("IsSuccess and operator bool for failure")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/status/404");
+ CHECK(!Resp.IsSuccess());
+ CHECK(!static_cast<bool>(Resp));
+ }
+
+ SUBCASE("AsText returns body")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("AsText returns empty for no-content")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/nocontent");
+ CHECK(Resp.AsText().empty());
+ }
+
+ SUBCASE("AsObject parses CbObject")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/json");
+ CbObject Obj = Resp.AsObject();
+ CHECK(Obj["ok"].AsBool() == true);
+ CHECK_EQ(Obj["message"].AsString(), "test");
+ }
+
+ // AsObject on a text body must yield a falsy/empty object, not throw.
+ SUBCASE("AsObject returns empty for non-CB content")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CbObject Obj = Resp.AsObject();
+ CHECK(!Obj);
+ }
+
+ SUBCASE("ToText for text content")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/text");
+ CHECK_EQ(Resp.ToText(), "plain text");
+ }
+
+ SUBCASE("ToText for CbObject content")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/json");
+ std::string Text = Resp.ToText();
+ CHECK(!Text.empty());
+ // ToText for CbObject converts to JSON string representation
+ CHECK(Text.find("ok") != std::string::npos);
+ CHECK(Text.find("test") != std::string::npos);
+ }
+
+ SUBCASE("ErrorMessage includes status code on failure")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/status/404");
+ std::string Msg = Resp.ErrorMessage("test-prefix");
+ CHECK(Msg.find("test-prefix") != std::string::npos);
+ CHECK(Msg.find("404") != std::string::npos);
+ }
+
+ SUBCASE("ThrowError throws on failure")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/status/500");
+ CHECK_THROWS_AS(Resp.ThrowError("test"), HttpClientError);
+ }
+
+ SUBCASE("ThrowError does not throw on success")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK_NOTHROW(Resp.ThrowError("test"));
+ }
+
+ SUBCASE("HttpClientError carries response code")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/status/403");
+ try
+ {
+ Resp.ThrowError("test");
+ CHECK(false); // should not reach
+ }
+ catch (const HttpClientError& Err)
+ {
+ CHECK_EQ(Err.GetHttpResponseCode(), HttpResponseCode::Forbidden);
+ }
+ }
+}
+
+// Failure modes: refused connections, server-side delay exceeding the client
+// timeout, and 404 from an unregistered route.
+TEST_CASE("httpclient.error-handling")
+{
+ // Port 19999 has no listener; the transport error surfaces in Resp.Error.
+ SUBCASE("Connection refused")
+ {
+ HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {});
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("Request timeout")
+ {
+ TestServerFixture Fixture;
+ HttpClientSettings Settings;
+ // 500 ms timeout vs. the 2 s delay in the "slow" route.
+ Settings.Timeout = std::chrono::milliseconds(500);
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/api/test/slow");
+ CHECK(!Resp.IsSuccess());
+ }
+
+ SUBCASE("Nonexistent endpoint returns failure")
+ {
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ HttpClient::Response Resp = Client.Get("/api/test/does-not-exist");
+ CHECK(!Resp.IsSuccess());
+ }
+}
+
+// Session-ID management: a default ID exists, SetSessionId overrides it, and
+// passing Oid::Zero resets back to a generated ID.
+TEST_CASE("httpclient.session")
+{
+ TestServerFixture Fixture;
+
+ SUBCASE("Default session ID is non-empty")
+ {
+ HttpClient Client = Fixture.MakeClient();
+ CHECK(!Client.GetSessionId().empty());
+ }
+
+ SUBCASE("SetSessionId changes ID")
+ {
+ HttpClient Client = Fixture.MakeClient();
+ Oid NewId = Oid::NewOid();
+ std::string OldId = std::string(Client.GetSessionId());
+ Client.SetSessionId(NewId);
+ CHECK_EQ(Client.GetSessionId(), NewId.ToString());
+ CHECK_NE(Client.GetSessionId(), OldId);
+ }
+
+ SUBCASE("SetSessionId with Zero resets")
+ {
+ HttpClient Client = Fixture.MakeClient();
+ Oid NewId = Oid::NewOid();
+ Client.SetSessionId(NewId);
+ CHECK_EQ(Client.GetSessionId(), NewId.ToString());
+ Client.SetSessionId(Oid::Zero);
+ // After resetting, should get a session string (not empty, not the custom one)
+ CHECK(!Client.GetSessionId().empty());
+ CHECK_NE(Client.GetSessionId(), NewId.ToString());
+ }
+}
+
+// Authenticate() behavior with/without an AccessTokenProvider (valid vs.
+// expired token), plus server-side bearer-token checks via auth/bearer.
+TEST_CASE("httpclient.authentication")
+{
+ TestServerFixture Fixture;
+
+ SUBCASE("Authenticate returns false without provider")
+ {
+ HttpClient Client = Fixture.MakeClient();
+ CHECK(!Client.Authenticate());
+ }
+
+ SUBCASE("Authenticate returns true with valid token")
+ {
+ HttpClientSettings Settings;
+ Settings.AccessTokenProvider = []() -> HttpClientAccessToken {
+ return HttpClientAccessToken{
+ .Value = "valid-token",
+ .ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours(1),
+ };
+ };
+ HttpClient Client = Fixture.MakeClient(Settings);
+ CHECK(Client.Authenticate());
+ }
+
+ // A token whose ExpireTime is already in the past must be rejected.
+ SUBCASE("Authenticate returns false with expired token")
+ {
+ HttpClientSettings Settings;
+ Settings.AccessTokenProvider = []() -> HttpClientAccessToken {
+ return HttpClientAccessToken{
+ .Value = "expired-token",
+ .ExpireTime = HttpClientAccessToken::Clock::now() - std::chrono::hours(1),
+ };
+ };
+ HttpClient Client = Fixture.MakeClient(Settings);
+ CHECK(!Client.Authenticate());
+ }
+
+ SUBCASE("Bearer token verified by auth endpoint")
+ {
+ HttpClient Client = Fixture.MakeClient();
+
+ HttpClient::Response AuthResp =
+ Client.Get("/api/test/auth/bearer", std::pair<std::string, std::string>("Authorization", "Bearer my-secret-token"));
+ CHECK(AuthResp.IsSuccess());
+ CHECK_EQ(AuthResp.AsText(), "authenticated");
+ }
+
+ SUBCASE("Request without token to auth endpoint gets 401")
+ {
+ HttpClient Client = Fixture.MakeClient();
+
+ HttpClient::Response Resp = Client.Get("/api/test/auth/bearer");
+ CHECK(!Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::Unauthorized);
+ }
+}
+
+// Server Content-Type headers must map onto the expected ZenContentType of
+// the response payload.
+TEST_CASE("httpclient.content-types")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("text content type")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/text");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kText);
+ }
+
+ SUBCASE("JSON content type")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/json");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kJSON);
+ }
+
+ SUBCASE("binary content type")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/binary");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kBinary);
+ }
+
+ SUBCASE("CbObject content type")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/cbobject");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kCbObject);
+ }
+}
+
+// Per-request metrics populated on the Response object.
+TEST_CASE("httpclient.metadata")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ // NOTE(review): strict > 0.0 could be flaky if the timer resolution is
+ // coarse relative to a loopback round-trip — has not been observed, but
+ // worth knowing if this ever fails intermittently.
+ SUBCASE("ElapsedSeconds is positive")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK(Resp.ElapsedSeconds > 0.0);
+ }
+
+ SUBCASE("DownloadedBytes populated for GET")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK(Resp.DownloadedBytes > 0);
+ }
+
+ SUBCASE("UploadedBytes populated for POST with payload")
+ {
+ const char* Payload = "some upload data";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+ Buf.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Buf);
+ CHECK(Resp.IsSuccess());
+ CHECK(Resp.UploadedBytes > 0);
+ }
+}
+
+// Retry policy against the attempt-counter route, which returns 503 for the
+// first N requests and 200 afterwards.
+TEST_CASE("httpclient.retry")
+{
+ TestServerFixture Fixture;
+
+ SUBCASE("Retry succeeds after transient failures")
+ {
+ // First two attempts fail; 3 retries are enough to reach the success.
+ Fixture.TestService.ResetAttemptCounter(2);
+
+ HttpClientSettings Settings;
+ Settings.RetryCount = 3;
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/api/test/attempt-counter");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "success after retries");
+ }
+
+ SUBCASE("No retry returns 503 immediately")
+ {
+ Fixture.TestService.ResetAttemptCounter(2);
+
+ HttpClientSettings Settings;
+ Settings.RetryCount = 0;
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/api/test/attempt-counter");
+ CHECK(!Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::ServiceUnavailable);
+ }
+}
+
+// MeasureLatency helper: succeeds with a positive latency against a live
+// server, and reports a failure reason when the target is unreachable.
+TEST_CASE("httpclient.measurelatency")
+{
+ SUBCASE("Successful measurement against live server")
+ {
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello");
+ CHECK(Result.Success);
+ CHECK(Result.LatencySeconds > 0.0);
+ }
+
+ SUBCASE("Failed measurement against unreachable port")
+ {
+ HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {});
+ LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello");
+ CHECK(!Result.Success);
+ CHECK(!Result.FailureReason.empty());
+ }
+}
+
+// Construction variants of HttpClient::KeyValueMap (the header/parameter
+// container); element access goes through operator->.
+TEST_CASE("httpclient.keyvaluemap")
+{
+ SUBCASE("Default construction is empty")
+ {
+ HttpClient::KeyValueMap Map;
+ CHECK(Map->empty());
+ }
+
+ SUBCASE("Construction from pair")
+ {
+ HttpClient::KeyValueMap Map(std::pair<std::string, std::string>("key", "value"));
+ CHECK_EQ(Map->size(), 1u);
+ CHECK_EQ(Map->at("key"), "value");
+ }
+
+ SUBCASE("Construction from string_view pair")
+ {
+ HttpClient::KeyValueMap Map(std::pair<std::string_view, std::string_view>("key"sv, "value"sv));
+ CHECK_EQ(Map->size(), 1u);
+ CHECK_EQ(Map->at("key"), "value");
+ }
+
+ SUBCASE("Construction from initializer list")
+ {
+ HttpClient::KeyValueMap Map({{"a"sv, "1"sv}, {"b"sv, "2"sv}});
+ CHECK_EQ(Map->size(), 2u);
+ CHECK_EQ(Map->at("a"), "1");
+ CHECK_EQ(Map->at("b"), "2");
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Transport fault testing
+
+// Builds a complete raw HTTP/1.1 response (headers + body) for the fault
+// server. NOTE(review): the reason phrase is always "OK" regardless of
+// StatusCode — clients ignore the reason phrase, so this is harmless here.
+static std::string
+MakeRawHttpResponse(int StatusCode, std::string_view Body)
+{
+ return fmt::format(
+ "HTTP/1.1 {} OK\r\n"
+ "Content-Type: text/plain\r\n"
+ "Content-Length: {}\r\n"
+ "\r\n"
+ "{}",
+ StatusCode,
+ Body.size(),
+ Body);
+}
+
+// Builds raw response headers only (no body), advertising ContentLength bytes
+// to follow. Used to construct truncated/stalled-body fault scenarios where
+// the advertised length is deliberately never delivered in full.
+static std::string
+MakeRawHttpHeaders(int StatusCode, size_t ContentLength)
+{
+ return fmt::format(
+ "HTTP/1.1 {} OK\r\n"
+ "Content-Type: application/octet-stream\r\n"
+ "Content-Length: {}\r\n"
+ "\r\n",
+ StatusCode,
+ ContentLength);
+}
+
+// Reads the incoming request up to the end of its headers ("\r\n\r\n") and
+// discards it; any request body is left unread. Errors are intentionally
+// ignored — the fault handlers don't care whether the read completed.
+static void
+DrainHttpRequest(asio::ip::tcp::socket& Socket)
+{
+ asio::streambuf Buf;
+ std::error_code Ec;
+ asio::read_until(Socket, Buf, "\r\n\r\n", Ec);
+}
+
+// Reads and discards an entire request: headers, then as many body bytes as
+// the Content-Length header announces (accounting for body bytes read_until
+// may already have buffered past the header delimiter).
+static void
+DrainFullHttpRequest(asio::ip::tcp::socket& Socket)
+{
+ // Read until end of headers
+ asio::streambuf Buf;
+ std::error_code Ec;
+ asio::read_until(Socket, Buf, "\r\n\r\n", Ec);
+ if (Ec)
+ {
+ return;
+ }
+
+ // Extract headers to find Content-Length
+ std::string Headers(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data()));
+
+ // Only the two capitalizations below are recognized; other casings (header
+ // names are case-insensitive per RFC 9110) would leave ContentLength at 0,
+ // which is acceptable for the clients used in these tests.
+ size_t ContentLength = 0;
+ auto Pos = Headers.find("Content-Length: ");
+ if (Pos == std::string::npos)
+ {
+ Pos = Headers.find("content-length: ");
+ }
+ if (Pos != std::string::npos)
+ {
+ size_t ValStart = Pos + 16; // length of "Content-Length: "
+ size_t ValEnd = Headers.find("\r\n", ValStart);
+ if (ValEnd != std::string::npos)
+ {
+ // NOTE(review): std::stoull throws on a malformed value; fine here
+ // since the peer is our own HttpClient sending well-formed requests.
+ ContentLength = std::stoull(Headers.substr(ValStart, ValEnd - ValStart));
+ }
+ }
+
+ // Calculate how many body bytes were already read past the header boundary.
+ // asio::read_until may read past the delimiter, so Buf.data() contains everything read.
+ size_t HeaderEnd = Headers.find("\r\n\r\n") + 4;
+ size_t BodyBytesInBuf = Headers.size() > HeaderEnd ? Headers.size() - HeaderEnd : 0;
+ size_t Remaining = ContentLength > BodyBytesInBuf ? ContentLength - BodyBytesInBuf : 0;
+
+ if (Remaining > 0)
+ {
+ std::vector<char> BodyBuf(Remaining);
+ asio::read(Socket, asio::buffer(BodyBuf), Ec);
+ }
+}
+
+// Reads the request headers plus at most BytesToRead bytes of the body, then
+// returns — leaving the rest of the body unread so fault handlers can abort
+// mid-upload. Accounts for body bytes read_until already buffered.
+static void
+DrainPartialBody(asio::ip::tcp::socket& Socket, size_t BytesToRead)
+{
+ // Read headers first
+ asio::streambuf Buf;
+ std::error_code Ec;
+ asio::read_until(Socket, Buf, "\r\n\r\n", Ec);
+ if (Ec)
+ {
+ return;
+ }
+
+ // Determine how many body bytes were already buffered past headers
+ std::string All(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data()));
+ size_t HeaderEnd = All.find("\r\n\r\n") + 4;
+ size_t BodyBytesInBuf = All.size() > HeaderEnd ? All.size() - HeaderEnd : 0;
+
+ if (BodyBytesInBuf < BytesToRead)
+ {
+ size_t Remaining = BytesToRead - BodyBytesInBuf;
+ std::vector<char> BodyBuf(Remaining);
+ asio::read(Socket, asio::buffer(BodyBuf), Ec);
+ }
+}
+
+// Minimal raw TCP server for injecting transport-level faults (resets,
+// truncation, stalls) that a real HTTP server would never produce. Binds an
+// ephemeral loopback port (port 0) and invokes the supplied handler once per
+// accepted connection, then re-arms the accept.
+struct FaultTcpServer
+{
+ using FaultHandler = std::function<void(asio::ip::tcp::socket&)>;
+
+ asio::io_context m_IoContext;
+ asio::ip::tcp::acceptor m_Acceptor;
+ FaultHandler m_Handler;
+ std::thread m_Thread;
+ int m_Port;
+
+ explicit FaultTcpServer(FaultHandler Handler)
+ : m_Acceptor(m_IoContext, asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), 0))
+ , m_Handler(std::move(Handler))
+ {
+ m_Port = m_Acceptor.local_endpoint().port();
+ StartAccept();
+ m_Thread = std::thread([this]() { m_IoContext.run(); });
+ }
+
+ ~FaultTcpServer()
+ {
+ // Closing the acceptor cancels the pending async_accept; stop() unblocks
+ // run() so the worker thread can be joined.
+ std::error_code Ec;
+ m_Acceptor.close(Ec);
+ m_IoContext.stop();
+ if (m_Thread.joinable())
+ {
+ m_Thread.join();
+ }
+ }
+
+ FaultTcpServer(const FaultTcpServer&) = delete;
+ FaultTcpServer& operator=(const FaultTcpServer&) = delete;
+
+ void StartAccept()
+ {
+ // The handler runs synchronously on the io_context thread, so a blocking
+ // handler (e.g. the stall scenario) delays subsequent accepts.
+ m_Acceptor.async_accept([this](std::error_code Ec, asio::ip::tcp::socket Socket) {
+ if (!Ec)
+ {
+ m_Handler(Socket);
+ }
+ if (m_Acceptor.is_open())
+ {
+ StartAccept();
+ }
+ });
+ }
+
+ HttpClient MakeClient(HttpClientSettings Settings = {})
+ {
+ return HttpClient(fmt::format("127.0.0.1:{}", m_Port), Settings, /*CheckIfAbortFunction*/ {});
+ }
+};
+
+TEST_CASE("httpclient.transport-faults" * doctest::skip())
+{
+ SUBCASE("connection reset before response")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("connection closed before response")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("partial headers then close")
+ {
+ // libcurl parses the status line (200 OK) and accepts the response even though
+ // headers are truncated mid-field. It reports success with an empty body instead
+ // of an error. Ideally this should be detected as a transport failure.
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Partial = "HTTP/1.1 200 OK\r\nContent-";
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Partial), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ WARN(!Resp.IsSuccess());
+ WARN(Resp.Error.has_value());
+ }
+
+ SUBCASE("truncated body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Headers = MakeRawHttpHeaders(200, 1000);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Headers), Ec);
+ std::string PartialBody(100, 'x');
+ asio::write(Socket, asio::buffer(PartialBody), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("connection reset mid-body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Headers = MakeRawHttpHeaders(200, 10000);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Headers), Ec);
+ std::string PartialBody(1000, 'x');
+ asio::write(Socket, asio::buffer(PartialBody), Ec);
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("stalled response triggers timeout")
+ {
+ std::atomic<bool> StallActive{true};
+ FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Headers = MakeRawHttpHeaders(200, 1000);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Headers), Ec);
+ while (StallActive.load())
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(50));
+ }
+ });
+
+ HttpClientSettings Settings;
+ Settings.Timeout = std::chrono::milliseconds(500);
+ HttpClient Client = Server.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ StallActive.store(false);
+ }
+
+ SUBCASE("retry succeeds after transient failures")
+ {
+ std::atomic<int> ConnCount{0};
+ FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) {
+ int N = ConnCount.fetch_add(1);
+ DrainHttpRequest(Socket);
+ if (N < 2)
+ {
+ // Connection reset produces NETWORK_SEND_FAILURE which is retryable
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ }
+ else
+ {
+ std::string Response = MakeRawHttpResponse(200, "recovered");
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ }
+ });
+
+ HttpClientSettings Settings;
+ Settings.RetryCount = 3;
+ HttpClient Client = Server.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "recovered");
+ }
+}
+
+TEST_CASE("httpclient.transport-faults-post" * doctest::skip())
+{
+ constexpr size_t kPostBodySize = 256 * 1024;
+
+ auto MakePostBody = []() -> IoBuffer {
+ IoBuffer Buf(kPostBodySize);
+ uint8_t* Ptr = static_cast<uint8_t*>(Buf.MutableData());
+ for (size_t i = 0; i < kPostBodySize; ++i)
+ {
+ Ptr[i] = static_cast<uint8_t>(i & 0xFF);
+ }
+ Buf.SetContentType(ZenContentType::kBinary);
+ return Buf;
+ };
+
+ SUBCASE("POST: server resets before consuming body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("POST: server closes before consuming body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("POST: server resets mid-body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainPartialBody(Socket, 8 * 1024);
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("POST: early error response before consuming body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Response = MakeRawHttpResponse(503, "service busy");
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ // With a large upload body, the server may RST the connection before the client
+ // reads the 503 response. Either outcome is valid: the client sees the HTTP 503
+ // status, or it sees a transport-level error from the RST.
+ CHECK((Resp.StatusCode == HttpResponseCode::ServiceUnavailable || Resp.Error.has_value()));
+ }
+
+ SUBCASE("POST: stalled upload triggers timeout")
+ {
+ std::atomic<bool> StallActive{true};
+ FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ // Stop reading body — TCP window will fill and client send will stall
+ while (StallActive.load())
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(50));
+ }
+ });
+
+ HttpClientSettings Settings;
+ Settings.Timeout = std::chrono::milliseconds(2000);
+ HttpClient Client = Server.MakeClient(Settings);
+
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ StallActive.store(false);
+ }
+
+ SUBCASE("POST: retry with large body after transient failure")
+ {
+ std::atomic<int> ConnCount{0};
+ FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) {
+ int N = ConnCount.fetch_add(1);
+ if (N < 2)
+ {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ }
+ else
+ {
+ DrainFullHttpRequest(Socket);
+ std::string Response = MakeRawHttpResponse(200, "upload-ok");
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ }
+ });
+
+ HttpClientSettings Settings;
+ Settings.RetryCount = 3;
+ HttpClient Client = Server.MakeClient(Settings);
+
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "upload-ok");
+ }
+}
+
+TEST_SUITE_END();
+
+void
+httpclient_test_forcelink()
+{
+}
+
+} // namespace zen
+
+#endif
diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp
index 72df12d02..02e1b57e2 100644
--- a/src/zenhttp/httpclientauth.cpp
+++ b/src/zenhttp/httpclientauth.cpp
@@ -170,7 +170,7 @@ namespace zen { namespace httpclientauth {
time_t UTCTime = timegm(&Time);
HttpClientAccessToken::TimePoint ExpireTime = std::chrono::system_clock::from_time_t(UTCTime);
- ExpireTime += std::chrono::microseconds(Millisecond);
+ ExpireTime += std::chrono::milliseconds(Millisecond);
return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime};
}
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
index c4e67d4ed..9bae95690 100644
--- a/src/zenhttp/httpserver.cpp
+++ b/src/zenhttp/httpserver.cpp
@@ -23,10 +23,12 @@
#include <zencore/logging.h>
#include <zencore/stream.h>
#include <zencore/string.h>
+#include <zencore/system.h>
#include <zencore/testing.h>
#include <zencore/thread.h>
#include <zenhttp/packageformat.h>
#include <zentelemetry/otlptrace.h>
+#include <zentelemetry/stats.h>
#include <charconv>
#include <mutex>
@@ -463,7 +465,7 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
//////////////////////////////////////////////////////////////////////////
-HttpServerRequest::HttpServerRequest(HttpService& Service) : m_BaseUri(Service.BaseUri())
+HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service)
{
}
@@ -745,6 +747,10 @@ HttpRequestRouter::RegisterRoute(const char* UriPattern, HttpRequestRouter::Hand
{
if (UriPattern[i] == '}')
{
+ if (i == PatternStart)
+ {
+ throw std::runtime_error(fmt::format("matcher pattern is empty in URI pattern '{}'", UriPattern));
+ }
std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart);
if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end())
{
@@ -910,8 +916,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
CapturedSegments.emplace_back(Uri);
- for (int MatcherIndex : Matchers)
+ for (size_t MatcherOffset = 0; MatcherOffset < Matchers.size(); MatcherOffset++)
{
+ int MatcherIndex = Matchers[MatcherOffset];
if (UriPos >= UriLen)
{
IsMatch = false;
@@ -921,9 +928,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
if (MatcherIndex < 0)
{
// Literal match
- int LitIndex = -MatcherIndex - 1;
- const std::string& LitStr = m_Literals[LitIndex];
- size_t LitLen = LitStr.length();
+ int LitIndex = -MatcherIndex - 1;
+ std::string_view LitStr = m_Literals[LitIndex];
+ size_t LitLen = LitStr.length();
if (Uri.substr(UriPos, LitLen) == LitStr)
{
@@ -939,9 +946,18 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
{
// Matcher function
size_t SegmentStart = UriPos;
- while (UriPos < UriLen && Uri[UriPos] != '/')
+
+ if (MatcherOffset == (Matchers.size() - 1))
+ {
+ // Last matcher, use the remaining part of the uri
+ UriPos = UriLen;
+ }
+ else
{
- ++UriPos;
+ while (UriPos < UriLen && Uri[UriPos] != '/')
+ {
+ ++UriPos;
+ }
}
std::string_view Segment = Uri.substr(SegmentStart, UriPos - SegmentStart);
@@ -970,7 +986,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan())
{
ExtendableStringBuilder<128> RoutePath;
- RoutePath.Append(Request.BaseUri());
+ RoutePath.Append(Request.Service().BaseUri());
RoutePath.Append(Handler.Pattern);
ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView());
}
@@ -994,7 +1010,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan())
{
ExtendableStringBuilder<128> RoutePath;
- RoutePath.Append(Request.BaseUri());
+ RoutePath.Append(Request.Service().BaseUri());
RoutePath.Append(Handler.Pattern);
ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView());
}
@@ -1014,7 +1030,28 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
int
HttpServer::Initialize(int BasePort, std::filesystem::path DataDir)
{
- return OnInitialize(BasePort, std::move(DataDir));
+ m_EffectivePort = OnInitialize(BasePort, std::move(DataDir));
+ m_ExternalHost = OnGetExternalHost();
+ return m_EffectivePort;
+}
+
+std::string
+HttpServer::OnGetExternalHost() const
+{
+ return GetMachineName();
+}
+
+std::string
+HttpServer::GetServiceUri(const HttpService* Service) const
+{
+ if (Service)
+ {
+ return fmt::format("http://{}:{}{}", m_ExternalHost, m_EffectivePort, Service->BaseUri());
+ }
+ else
+ {
+ return fmt::format("http://{}:{}", m_ExternalHost, m_EffectivePort);
+ }
}
void
@@ -1052,6 +1089,45 @@ HttpServer::EnumerateServices(std::function<void(HttpService& Service)>&& Callba
}
}
+void
+HttpServer::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ OnSetHttpRequestFilter(RequestFilter);
+}
+
+CbObject
+HttpServer::CollectStats()
+{
+ CbObjectWriter Cbo;
+
+ metrics::EmitSnapshot("requests", m_RequestMeter, Cbo);
+
+ Cbo.BeginObject("bytes");
+ {
+ Cbo << "received" << GetTotalBytesReceived();
+ Cbo << "sent" << GetTotalBytesSent();
+ }
+ Cbo.EndObject();
+
+ Cbo.BeginObject("websockets");
+ {
+ Cbo << "active_connections" << GetActiveWebSocketConnectionCount();
+ Cbo << "frames_received" << m_WsFramesReceived.load(std::memory_order_relaxed);
+ Cbo << "frames_sent" << m_WsFramesSent.load(std::memory_order_relaxed);
+ Cbo << "bytes_received" << m_WsBytesReceived.load(std::memory_order_relaxed);
+ Cbo << "bytes_sent" << m_WsBytesSent.load(std::memory_order_relaxed);
+ }
+ Cbo.EndObject();
+
+ return Cbo.Save();
+}
+
+void
+HttpServer::HandleStatsRequest(HttpServerRequest& Request)
+{
+ Request.WriteResponse(HttpResponseCode::OK, CollectStats());
+}
+
//////////////////////////////////////////////////////////////////////////
HttpRpcHandler::HttpRpcHandler()
@@ -1294,6 +1370,8 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("http.httpserver");
+
TEST_CASE("http.common")
{
using namespace std::literals;
@@ -1310,7 +1388,11 @@ TEST_CASE("http.common")
{
TestHttpServerRequest(HttpService& Service, std::string_view Uri) : HttpServerRequest(Service) { m_Uri = Uri; }
virtual IoBuffer ReadPayload() override { return IoBuffer(); }
- virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override
+
+ virtual bool IsLocalMachineRequest() const override { return false; }
+ virtual std::string_view GetAuthorizationHeader() const override { return {}; }
+
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override
{
ZEN_UNUSED(ResponseCode, ContentType, Blobs);
}
@@ -1395,20 +1477,33 @@ TEST_CASE("http.common")
SUBCASE("router-matcher")
{
- bool HandledA = false;
- bool HandledAA = false;
- bool HandledAB = false;
- bool HandledAandB = false;
+ bool HandledA = false;
+ bool HandledAA = false;
+ bool HandledAB = false;
+ bool HandledAandB = false;
+ bool HandledAandPath = false;
std::vector<std::string> Captures;
auto Reset = [&] {
- HandledA = HandledAA = HandledAB = HandledAandB = false;
+ HandledA = HandledAA = HandledAB = HandledAandB = HandledAandPath = false;
Captures.clear();
};
TestHttpService Service;
HttpRequestRouter r;
- r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0; });
- r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0; });
+
+ r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0 && In.find('/') == std::string_view::npos; });
+ r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0 && In.find('/') == std::string_view::npos; });
+ static constexpr AsciiSet ValidPathCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789/_.,;$~{}+-[]%()]ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+ r.AddMatcher("path", [](std::string_view Str) -> bool { return !Str.empty() && AsciiSet::HasOnly(Str, ValidPathCharactersSet); });
+
+ r.RegisterRoute(
+ "path/{a}/{path}",
+ [&](auto& Req) {
+ HandledAandPath = true;
+ Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))};
+ },
+ HttpVerb::kGet);
+
r.RegisterRoute(
"{a}",
[&](auto& Req) {
@@ -1437,7 +1532,6 @@ TEST_CASE("http.common")
Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))};
},
HttpVerb::kGet);
-
{
Reset();
TestHttpServerRequest req{Service, "ab"sv};
@@ -1445,6 +1539,7 @@ TEST_CASE("http.common")
CHECK(HandledA);
CHECK(!HandledAA);
CHECK(!HandledAB);
+ CHECK(!HandledAandPath);
REQUIRE_EQ(Captures.size(), 1);
CHECK_EQ(Captures[0], "ab"sv);
@@ -1457,6 +1552,7 @@ TEST_CASE("http.common")
CHECK(!HandledA);
CHECK(!HandledAA);
CHECK(HandledAB);
+ CHECK(!HandledAandPath);
REQUIRE_EQ(Captures.size(), 2);
CHECK_EQ(Captures[0], "ab"sv);
CHECK_EQ(Captures[1], "def"sv);
@@ -1470,6 +1566,7 @@ TEST_CASE("http.common")
CHECK(!HandledAA);
CHECK(!HandledAB);
CHECK(HandledAandB);
+ CHECK(!HandledAandPath);
REQUIRE_EQ(Captures.size(), 2);
CHECK_EQ(Captures[0], "ab"sv);
CHECK_EQ(Captures[1], "def"sv);
@@ -1482,6 +1579,7 @@ TEST_CASE("http.common")
CHECK(!HandledA);
CHECK(!HandledAA);
CHECK(!HandledAB);
+ CHECK(!HandledAandPath);
}
{
@@ -1491,6 +1589,35 @@ TEST_CASE("http.common")
CHECK(HandledA);
CHECK(!HandledAA);
CHECK(!HandledAB);
+ CHECK(!HandledAandPath);
+ REQUIRE_EQ(Captures.size(), 1);
+ CHECK_EQ(Captures[0], "a123"sv);
+ }
+
+ {
+ Reset();
+ TestHttpServerRequest req{Service, "path/ab/simple_path.txt"sv};
+ r.HandleRequest(req);
+ CHECK(!HandledA);
+ CHECK(!HandledAA);
+ CHECK(!HandledAB);
+ CHECK(HandledAandPath);
+ REQUIRE_EQ(Captures.size(), 2);
+ CHECK_EQ(Captures[0], "ab"sv);
+ CHECK_EQ(Captures[1], "simple_path.txt"sv);
+ }
+
+ {
+ Reset();
+ TestHttpServerRequest req{Service, "path/ab/directory/and/path.txt"sv};
+ r.HandleRequest(req);
+ CHECK(!HandledA);
+ CHECK(!HandledAA);
+ CHECK(!HandledAB);
+ CHECK(HandledAandPath);
+ REQUIRE_EQ(Captures.size(), 2);
+ CHECK_EQ(Captures[0], "ab"sv);
+ CHECK_EQ(Captures[1], "directory/and/path.txt"sv);
}
}
@@ -1508,6 +1635,8 @@ TEST_CASE("http.common")
}
}
+TEST_SUITE_END();
+
void
http_forcelink()
{
diff --git a/src/zenhttp/include/zenhttp/cprutils.h b/src/zenhttp/include/zenhttp/cprutils.h
index a988346e0..c252a5d99 100644
--- a/src/zenhttp/include/zenhttp/cprutils.h
+++ b/src/zenhttp/include/zenhttp/cprutils.h
@@ -66,10 +66,10 @@ struct fmt::formatter<cpr::Response>
Response.url.str(),
Response.status_code,
zen::ToString(zen::HttpResponseCode(Response.status_code)),
+ Response.reason,
Response.uploaded_bytes,
Response.downloaded_bytes,
NiceResponseTime.c_str(),
- Response.reason,
Json);
}
else
@@ -82,10 +82,10 @@ struct fmt::formatter<cpr::Response>
Response.url.str(),
Response.status_code,
zen::ToString(zen::HttpResponseCode(Response.status_code)),
+ Response.reason,
Response.uploaded_bytes,
Response.downloaded_bytes,
NiceResponseTime.c_str(),
- Response.reason,
Body.GetText());
}
}
diff --git a/src/zenhttp/include/zenhttp/formatters.h b/src/zenhttp/include/zenhttp/formatters.h
index addb00cb8..57ab01158 100644
--- a/src/zenhttp/include/zenhttp/formatters.h
+++ b/src/zenhttp/include/zenhttp/formatters.h
@@ -73,7 +73,7 @@ struct fmt::formatter<zen::HttpClient::Response>
if (Response.IsSuccess())
{
return fmt::format_to(Ctx.out(),
- "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s",
+ "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}",
ToString(Response.StatusCode),
Response.UploadedBytes,
Response.DownloadedBytes,
diff --git a/src/zenhttp/include/zenhttp/httpapiservice.h b/src/zenhttp/include/zenhttp/httpapiservice.h
index 0270973bf..2d384d1d8 100644
--- a/src/zenhttp/include/zenhttp/httpapiservice.h
+++ b/src/zenhttp/include/zenhttp/httpapiservice.h
@@ -1,4 +1,5 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#pragma once
#include <zenhttp/httpserver.h>
diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h
index 9a9b74d72..1bb36a298 100644
--- a/src/zenhttp/include/zenhttp/httpclient.h
+++ b/src/zenhttp/include/zenhttp/httpclient.h
@@ -13,6 +13,7 @@
#include <functional>
#include <optional>
#include <unordered_map>
+#include <vector>
namespace zen {
@@ -58,6 +59,10 @@ struct HttpClientSettings
Oid SessionId = Oid::Zero;
bool Verbose = false;
uint64_t MaximumInMemoryDownloadSize = 1024u * 1024u;
+
+ /// HTTP status codes that are expected and should not be logged as warnings.
+ /// 404 is always treated as expected regardless of this list.
+ std::vector<HttpResponseCode> ExpectedErrorCodes;
};
class HttpClientError : public std::runtime_error
@@ -113,6 +118,15 @@ private:
class HttpClientBase;
+/** HTTP Client
+ *
+ * This is safe for use on multiple threads simultaneously, as each
+ * instance maintains an internal connection pool and will synchronize
+ * access to it as needed.
+ *
+ * Uses libcurl under the hood. We currently only use HTTP 1.1 features.
+ *
+ */
class HttpClient
{
public:
@@ -123,8 +137,11 @@ public:
struct ErrorContext
{
- int ErrorCode;
+ int ErrorCode = 0;
std::string ErrorMessage;
+
+ /** True when the error is a transport-level connection failure (connect timeout, refused, DNS) */
+ bool IsConnectionError() const;
};
struct KeyValueMap
@@ -171,13 +188,29 @@ public:
KeyValueMap Header;
// The number of bytes sent as part of the request
- int64_t UploadedBytes;
+ int64_t UploadedBytes = 0;
// The number of bytes received as part of the response
- int64_t DownloadedBytes;
+ int64_t DownloadedBytes = 0;
// The elapsed time in seconds for the request to execute
- double ElapsedSeconds;
+ double ElapsedSeconds = 0.0;
+
+ struct MultipartBoundary
+ {
+ uint64_t OffsetInPayload = 0;
+ uint64_t RangeOffset = 0;
+ uint64_t RangeLength = 0;
+ HttpContentType ContentType;
+ };
+
+ // Ranges will map out all received ranges, both single and multi-range responses
+ // If no range was requested Ranges will be empty
+ std::vector<MultipartBoundary> Ranges;
+
+ // Map the absolute OffsetAndLengthPairs into ResponsePayload from the ranges received (Ranges).
+ // If the response was not a partial response, an empty vector will be returned
+ std::vector<std::pair<uint64_t, uint64_t>> GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const;
// This contains any errors from the HTTP stack. It won't contain information on
// why the server responded with a non-success HTTP status, that may be gleaned
@@ -260,6 +293,16 @@ private:
const HttpClientSettings m_ConnectionSettings;
};
-void httpclient_forcelink(); // internal
+struct LatencyTestResult
+{
+ bool Success = false;
+ std::string FailureReason;
+ double LatencySeconds = -1.0;
+};
+
+LatencyTestResult MeasureLatency(HttpClient& Client, std::string_view Url);
+
+void httpclient_forcelink(); // internal
+void httpclient_test_forcelink(); // internal
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h
index bc18549c9..8fca35ac5 100644
--- a/src/zenhttp/include/zenhttp/httpcommon.h
+++ b/src/zenhttp/include/zenhttp/httpcommon.h
@@ -184,6 +184,13 @@ IsHttpSuccessCode(HttpResponseCode HttpCode) noexcept
return IsHttpSuccessCode(int(HttpCode));
}
+[[nodiscard]] inline bool
+IsHttpOk(HttpResponseCode HttpCode) noexcept
+{
+ return HttpCode == HttpResponseCode::OK || HttpCode == HttpResponseCode::Created || HttpCode == HttpResponseCode::Accepted ||
+ HttpCode == HttpResponseCode::NoContent;
+}
+
std::string_view ToString(HttpResponseCode HttpCode);
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
index 3438a1471..0e1714669 100644
--- a/src/zenhttp/include/zenhttp/httpserver.h
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -13,6 +13,8 @@
#include <zencore/uid.h>
#include <zenhttp/httpcommon.h>
+#include <zentelemetry/stats.h>
+
#include <functional>
#include <gsl/gsl-lite.hpp>
#include <list>
@@ -30,16 +32,18 @@ class HttpService;
*/
class HttpServerRequest
{
-public:
+protected:
explicit HttpServerRequest(HttpService& Service);
+
+public:
~HttpServerRequest();
// Synchronous operations
[[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix
- [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; }
+ [[nodiscard]] inline std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; }
[[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; }
- [[nodiscard]] inline std::string_view BaseUri() const { return m_BaseUri; } // Service prefix
+ [[nodiscard]] inline HttpService& Service() const { return m_Service; }
struct QueryParams
{
@@ -79,6 +83,18 @@ public:
inline bool IsHandled() const { return !!(m_Flags & kIsHandled); }
inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); }
inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; }
+ inline void SetLogRequest(bool ShouldLog)
+ {
+ if (ShouldLog)
+ {
+ m_Flags |= kLogRequest;
+ }
+ else
+ {
+ m_Flags &= ~kLogRequest;
+ }
+ }
+ inline bool ShouldLogRequest() const { return !!(m_Flags & kLogRequest); }
/** Read POST/PUT payload for request body, which is always available without delay
*/
@@ -87,6 +103,10 @@ public:
CbObject ReadPayloadObject();
CbPackage ReadPayloadPackage();
+ virtual bool IsLocalMachineRequest() const = 0;
+ virtual std::string_view GetAuthorizationHeader() const = 0;
+ virtual std::string_view GetRemoteAddress() const { return {}; }
+
/** Respond with payload
No data will have been sent when any of these functions return. Instead, the response will be transmitted
@@ -115,15 +135,17 @@ protected:
kSuppressBody = 1 << 1,
kHaveRequestId = 1 << 2,
kHaveSessionId = 1 << 3,
+ kLogRequest = 1 << 4,
};
- mutable uint32_t m_Flags = 0;
+ mutable uint32_t m_Flags = 0;
+
+ HttpService& m_Service; // Service handling this request
HttpVerb m_Verb = HttpVerb::kGet;
HttpContentType m_ContentType = HttpContentType::kBinary;
HttpContentType m_AcceptType = HttpContentType::kUnknownContentType;
uint64_t m_ContentLength = ~0ull;
- std::string_view m_BaseUri; // Base URI path of the service handling this request
- std::string_view m_Uri; // URI without service prefix
+ std::string_view m_Uri; // URI without service prefix
std::string_view m_UriWithExtension;
std::string_view m_QueryString;
mutable uint32_t m_RequestId = ~uint32_t(0);
@@ -144,6 +166,19 @@ public:
virtual void OnRequestComplete() = 0;
};
+class IHttpRequestFilter
+{
+public:
+ virtual ~IHttpRequestFilter() {}
+ enum class Result
+ {
+ Forbidden,
+ ResponseSent,
+ Accepted
+ };
+ virtual Result FilterRequest(HttpServerRequest& Request) = 0;
+};
+
/**
* Base class for implementing an HTTP "service"
*
@@ -170,30 +205,110 @@ private:
int m_UriPrefixLength = 0;
};
+struct IHttpStatsProvider
+{
+ /** Handle an HTTP stats request, writing the response directly.
+ * Implementations may inspect query parameters on the request
+ * to include optional detailed breakdowns.
+ */
+ virtual void HandleStatsRequest(HttpServerRequest& Request) = 0;
+
+ /** Return the provider's current stats as a CbObject snapshot.
+ * Used by the WebSocket push thread to broadcast live updates
+ * without requiring an HttpServerRequest. Providers that do
+ * not override this will be skipped in WebSocket broadcasts.
+ */
+ virtual CbObject CollectStats() { return {}; }
+};
+
+struct IHttpStatsService
+{
+ virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
+ virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
+};
+
/** HTTP server
*
* Implements the main event loop to service HTTP requests, and handles routing
* requests to the appropriate handler as registered via RegisterService
*/
-class HttpServer : public RefCounted
+class HttpServer : public RefCounted, public IHttpStatsProvider
{
public:
void RegisterService(HttpService& Service);
void EnumerateServices(std::function<void(HttpService&)>&& Callback);
+ void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter);
int Initialize(int BasePort, std::filesystem::path DataDir);
void Run(bool IsInteractiveSession);
void RequestExit();
void Close();
+ /** Returns a canonical http:// URI for the given service, using the external
+ * IP and the port the server is actually listening on. Only valid
+ * after Initialize() has returned successfully.
+ */
+ std::string GetServiceUri(const HttpService* Service) const;
+
+ /** Returns the external host string (IP or hostname) determined during Initialize().
+ * Only valid after Initialize() has returned successfully.
+ */
+ std::string_view GetExternalHost() const { return m_ExternalHost; }
+
+ /** Returns total bytes received and sent across all connections since server start. */
+ virtual uint64_t GetTotalBytesReceived() const { return 0; }
+ virtual uint64_t GetTotalBytesSent() const { return 0; }
+
+ /** Mark that a request has been handled. Called by server implementations. */
+ void MarkRequest() { m_RequestMeter.Mark(); }
+
+ /** Set a default redirect path for root requests */
+ void SetDefaultRedirect(std::string_view Path) { m_DefaultRedirect = Path; }
+
+ std::string_view GetDefaultRedirect() const { return m_DefaultRedirect; }
+
+ /** Track active WebSocket connections — called by server implementations on upgrade/close. */
+ void OnWebSocketConnectionOpened() { m_ActiveWebSocketConnections.fetch_add(1, std::memory_order_relaxed); }
+ void OnWebSocketConnectionClosed() { m_ActiveWebSocketConnections.fetch_sub(1, std::memory_order_relaxed); }
+ uint64_t GetActiveWebSocketConnectionCount() const { return m_ActiveWebSocketConnections.load(std::memory_order_relaxed); }
+
+ /** Track WebSocket frame and byte counters — called by WS connection implementations per frame. */
+ void OnWebSocketFrameReceived(uint64_t Bytes)
+ {
+ m_WsFramesReceived.fetch_add(1, std::memory_order_relaxed);
+ m_WsBytesReceived.fetch_add(Bytes, std::memory_order_relaxed);
+ }
+ void OnWebSocketFrameSent(uint64_t Bytes)
+ {
+ m_WsFramesSent.fetch_add(1, std::memory_order_relaxed);
+ m_WsBytesSent.fetch_add(Bytes, std::memory_order_relaxed);
+ }
+
+ // IHttpStatsProvider
+ virtual CbObject CollectStats() override;
+ virtual void HandleStatsRequest(HttpServerRequest& Request) override;
+
private:
std::vector<HttpService*> m_KnownServices;
+ int m_EffectivePort = 0;
+ std::string m_ExternalHost;
+ metrics::Meter m_RequestMeter;
+ std::string m_DefaultRedirect;
+ std::atomic<uint64_t> m_ActiveWebSocketConnections{0};
+ std::atomic<uint64_t> m_WsFramesReceived{0};
+ std::atomic<uint64_t> m_WsFramesSent{0};
+ std::atomic<uint64_t> m_WsBytesReceived{0};
+ std::atomic<uint64_t> m_WsBytesSent{0};
virtual void OnRegisterService(HttpService& Service) = 0;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) = 0;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) = 0;
virtual void OnRun(bool IsInteractiveSession) = 0;
virtual void OnRequestExit() = 0;
virtual void OnClose() = 0;
+
+protected:
+ virtual std::string OnGetExternalHost() const;
};
struct HttpServerPluginConfig
@@ -236,7 +351,7 @@ public:
inline HttpServerRequest& ServerRequest() { return m_HttpRequest; }
private:
- HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {}
+ explicit HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {}
~HttpRouterRequest() = default;
HttpRouterRequest(const HttpRouterRequest&) = delete;
@@ -385,7 +500,7 @@ public:
~HttpRpcHandler();
HttpRpcHandler(const HttpRpcHandler&) = delete;
- HttpRpcHandler operator=(const HttpRpcHandler&) = delete;
+ HttpRpcHandler& operator=(const HttpRpcHandler&) = delete;
void AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction);
@@ -401,17 +516,7 @@ private:
bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef);
-struct IHttpStatsProvider
-{
- virtual void HandleStatsRequest(HttpServerRequest& Request) = 0;
-};
-
-struct IHttpStatsService
-{
- virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
- virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
-};
-
-void http_forcelink(); // internal
+void http_forcelink(); // internal
+void websocket_forcelink(); // internal
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h
index e6fea6765..460315faf 100644
--- a/src/zenhttp/include/zenhttp/httpstats.h
+++ b/src/zenhttp/include/zenhttp/httpstats.h
@@ -3,23 +3,50 @@
#pragma once
#include <zencore/logging.h>
+#include <zencore/thread.h>
#include <zenhttp/httpserver.h>
+#include <zenhttp/websocket.h>
+#include <atomic>
#include <map>
+#include <memory>
+#include <thread>
+#include <vector>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio/io_context.hpp>
+#include <asio/steady_timer.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
-class HttpStatsService : public HttpService, public IHttpStatsService
+class HttpStatsService : public HttpService, public IHttpStatsService, public IWebSocketHandler
{
public:
- HttpStatsService();
+ /// Construct without an io_context — optionally uses a dedicated push thread
+ /// for WebSocket stats broadcasting.
+ explicit HttpStatsService(bool EnableWebSockets = false);
+
+ /// Construct with an external io_context — uses an asio timer instead
+ /// of a dedicated thread for WebSocket stats broadcasting.
+ /// The caller must ensure the io_context outlives this service and that
+ /// its run loop is active.
+ HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets = true);
+
~HttpStatsService();
+ void Shutdown();
+
virtual const char* BaseUri() const override;
virtual void HandleRequest(HttpServerRequest& Request) override;
virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override;
virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override;
+ // IWebSocketHandler
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
+
private:
LoggerRef m_Log;
HttpRequestRouter m_Router;
@@ -28,6 +55,22 @@ private:
RwLock m_Lock;
std::map<std::string, IHttpStatsProvider*> m_Providers;
+
+ // WebSocket push
+ RwLock m_WsConnectionsLock;
+ std::vector<Ref<WebSocketConnection>> m_WsConnections;
+ std::atomic<bool> m_PushEnabled{false};
+
+ void BroadcastStats();
+
+ // Thread-based push (when no io_context is provided)
+ std::thread m_PushThread;
+ Event m_PushEvent;
+ void PushThreadFunction();
+
+ // Timer-based push (when an io_context is provided)
+ std::unique_ptr<asio::steady_timer> m_PushTimer;
+ void EnqueuePushTimer();
};
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h
new file mode 100644
index 000000000..926ec1e3d
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpwsclient.h
@@ -0,0 +1,79 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenhttp.h"
+
+#include <zenhttp/httpclient.h>
+#include <zenhttp/websocket.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio/io_context.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <chrono>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <span>
+#include <string>
+#include <string_view>
+
+namespace zen {
+
+/**
+ * Callback interface for WebSocket client events
+ *
+ * Separate from the server-side IWebSocketHandler because the caller
+ * already owns the HttpWsClient — no Ref<WebSocketConnection> needed.
+ */
+class IWsClientHandler
+{
+public:
+ virtual ~IWsClientHandler() = default;
+
+ virtual void OnWsOpen() = 0;
+ virtual void OnWsMessage(const WebSocketMessage& Msg) = 0;
+ virtual void OnWsClose(uint16_t Code, std::string_view Reason) = 0;
+};
+
+struct HttpWsClientSettings
+{
+ std::string LogCategory = "wsclient";
+ std::chrono::milliseconds ConnectTimeout{5000};
+ std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider;
+};
+
+/**
+ * WebSocket client over TCP (ws:// scheme)
+ *
+ * Uses ASIO for async I/O. Two construction modes:
+ * - Internal io_context + background thread (standalone use)
+ * - External io_context (shared event loop, no internal thread)
+ *
+ * Thread-safe for SendText/SendBinary/Close.
+ */
+class HttpWsClient
+{
+public:
+ HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings = {});
+ HttpWsClient(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings = {});
+
+ ~HttpWsClient();
+
+ HttpWsClient(const HttpWsClient&) = delete;
+ HttpWsClient& operator=(const HttpWsClient&) = delete;
+
+ void Connect();
+ void SendText(std::string_view Text);
+ void SendBinary(std::span<const uint8_t> Data);
+ void Close(uint16_t Code = 1000, std::string_view Reason = {});
+ bool IsOpen() const;
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h
index c90b840da..1a5068580 100644
--- a/src/zenhttp/include/zenhttp/packageformat.h
+++ b/src/zenhttp/include/zenhttp/packageformat.h
@@ -68,7 +68,7 @@ struct CbAttachmentEntry
struct CbAttachmentReferenceHeader
{
uint64_t PayloadByteOffset = 0;
- uint64_t PayloadByteSize = ~0u;
+ uint64_t PayloadByteSize = ~uint64_t(0);
uint16_t AbsolutePathLength = 0;
// This header will be followed by UTF8 encoded absolute path to backing file
diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurity.h b/src/zenhttp/include/zenhttp/security/passwordsecurity.h
new file mode 100644
index 000000000..6b2b548a6
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/security/passwordsecurity.h
@@ -0,0 +1,38 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+class PasswordSecurity
+{
+public:
+ struct Configuration
+ {
+ std::string Password;
+ bool ProtectMachineLocalRequests = false;
+ std::vector<std::string> UnprotectedUris;
+ };
+
+ explicit PasswordSecurity(const Configuration& Config);
+
+ [[nodiscard]] inline std::string_view Password() const { return m_Config.Password; }
+ [[nodiscard]] inline bool ProtectMachineLocalRequests() const { return m_Config.ProtectMachineLocalRequests; }
+ [[nodiscard]] bool IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const;
+
+ bool IsAllowed(std::string_view Password, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest);
+
+private:
+ const Configuration m_Config;
+ tsl::robin_map<uint32_t, uint32_t> m_UnprotectedUriHashes;
+};
+
+void passwordsecurity_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h
new file mode 100644
index 000000000..c098f05ad
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h
@@ -0,0 +1,51 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+#include <zenhttp/security/passwordsecurity.h>
+
+namespace zen {
+
+class PasswordHttpFilter : public IHttpRequestFilter
+{
+public:
+ static constexpr std::string_view TypeName = "password";
+
+ struct Configuration
+ {
+ PasswordSecurity::Configuration PasswordConfig;
+ std::string AuthenticationTypeString;
+ };
+
+ /**
+ * Expected format (Json)
+ * {
+ * "password": { # "Authorization: Basic <username:password base64 encoded>" style
+ * "username": "<username>",
+ * "password": "<password>"
+ * },
+ * "protect-machine-local-requests": false,
+ * "unprotected-uris": [
+ * "/health/",
+ * "/health/info",
+ * "/health/version"
+ * ]
+ * }
+ */
+ static Configuration ReadConfiguration(CbObjectView Config);
+
+ explicit PasswordHttpFilter(const PasswordHttpFilter::Configuration& Config)
+ : m_PasswordSecurity(Config.PasswordConfig)
+ , m_AuthenticationTypeString(Config.AuthenticationTypeString)
+ {
+ }
+
+ virtual Result FilterRequest(HttpServerRequest& Request) override;
+
+private:
+ PasswordSecurity m_PasswordSecurity;
+ const std::string m_AuthenticationTypeString;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h
new file mode 100644
index 000000000..bc3293282
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/websocket.h
@@ -0,0 +1,65 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenbase/refcount.h>
+#include <zencore/iobuffer.h>
+
+#include <cstdint>
+#include <span>
+#include <string_view>
+
+namespace zen {
+
+enum class WebSocketOpcode : uint8_t
+{
+ kText = 0x1,
+ kBinary = 0x2,
+ kClose = 0x8,
+ kPing = 0x9,
+ kPong = 0xA
+};
+
+struct WebSocketMessage
+{
+ WebSocketOpcode Opcode = WebSocketOpcode::kText;
+ IoBuffer Payload;
+ uint16_t CloseCode = 0;
+};
+
+/**
+ * Represents an active WebSocket connection
+ *
+ * Derived classes implement the actual transport (e.g. ASIO sockets).
+ * Instances are reference-counted so that both the service layer and
+ * the async read/write loop can share ownership.
+ */
+class WebSocketConnection : public RefCounted
+{
+public:
+ virtual ~WebSocketConnection() = default;
+
+ virtual void SendText(std::string_view Text) = 0;
+ virtual void SendBinary(std::span<const uint8_t> Data) = 0;
+ virtual void Close(uint16_t Code = 1000, std::string_view Reason = {}) = 0;
+ virtual bool IsOpen() const = 0;
+};
+
+/**
+ * Interface for services that accept WebSocket upgrades
+ *
+ * An HttpService may additionally implement this interface to indicate
+ * it supports WebSocket connections. The HTTP server checks for this
+ * via dynamic_cast when it sees an Upgrade: websocket request.
+ */
+class IWebSocketHandler
+{
+public:
+ virtual ~IWebSocketHandler() = default;
+
+ virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0;
+ virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0;
+ virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp
index b097a0d3f..2370def0c 100644
--- a/src/zenhttp/monitoring/httpstats.cpp
+++ b/src/zenhttp/monitoring/httpstats.cpp
@@ -3,15 +3,57 @@
#include "zenhttp/httpstats.h"
#include <zencore/compactbinarybuilder.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/trace.h>
namespace zen {
-HttpStatsService::HttpStatsService() : m_Log(logging::Get("stats"))
+HttpStatsService::HttpStatsService(bool EnableWebSockets) : m_Log(logging::Get("stats"))
{
+ if (EnableWebSockets)
+ {
+ m_PushEnabled.store(true);
+ m_PushThread = std::thread([this] { PushThreadFunction(); });
+ }
+}
+
+HttpStatsService::HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets) : m_Log(logging::Get("stats"))
+{
+ if (EnableWebSockets)
+ {
+ m_PushEnabled.store(true);
+ m_PushTimer = std::make_unique<asio::steady_timer>(IoContext);
+ EnqueuePushTimer();
+ }
}
HttpStatsService::~HttpStatsService()
{
+ Shutdown();
+}
+
+void
+HttpStatsService::Shutdown()
+{
+ if (!m_PushEnabled.exchange(false))
+ {
+ return;
+ }
+
+ if (m_PushTimer)
+ {
+ m_PushTimer->cancel();
+ m_PushTimer.reset();
+ }
+
+ if (m_PushThread.joinable())
+ {
+ m_PushEvent.Set();
+ m_PushThread.join();
+ }
+
+ m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); });
}
const char*
@@ -39,6 +81,7 @@ HttpStatsService::UnregisterHandler(std::string_view Id, IHttpStatsProvider& Pro
void
HttpStatsService::HandleRequest(HttpServerRequest& Request)
{
+ ZEN_TRACE_CPU("HttpStatsService::HandleRequest");
using namespace std::literals;
std::string_view Key = Request.RelativeUri();
@@ -89,4 +132,154 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request)
}
}
+//////////////////////////////////////////////////////////////////////////
+//
+// IWebSocketHandler
+//
+
+void
+HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+{
+ ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen");
+ ZEN_INFO("Stats WebSocket client connected");
+
+ m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
+
+ // Send initial state immediately
+ if (m_PushThread.joinable())
+ {
+ m_PushEvent.Set();
+ }
+}
+
+void
+HttpStatsService::OnWebSocketMessage(WebSocketConnection& /*Conn*/, const WebSocketMessage& /*Msg*/)
+{
+ // No client-to-server messages expected
+}
+
+void
+HttpStatsService::OnWebSocketClose(WebSocketConnection& Conn, [[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason)
+{
+ ZEN_TRACE_CPU("HttpStatsService::OnWebSocketClose");
+ ZEN_INFO("Stats WebSocket client disconnected (code {})", Code);
+
+ m_WsConnectionsLock.WithExclusiveLock([&] {
+ auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) {
+ return C.Get() == &Conn;
+ });
+ m_WsConnections.erase(It, m_WsConnections.end());
+ });
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Stats broadcast
+//
+
+void
+HttpStatsService::BroadcastStats()
+{
+ ZEN_TRACE_CPU("HttpStatsService::BroadcastStats");
+ std::vector<Ref<WebSocketConnection>> Connections;
+ m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; });
+
+ if (Connections.empty())
+ {
+ return;
+ }
+
+ // Collect stats from all providers
+ ExtendableStringBuilder<4096> JsonBuilder;
+ JsonBuilder.Append("{");
+
+ bool First = true;
+ {
+ RwLock::SharedLockScope _(m_Lock);
+ for (auto& [Id, Provider] : m_Providers)
+ {
+ CbObject Stats = Provider->CollectStats();
+ if (!Stats)
+ {
+ continue;
+ }
+
+ if (!First)
+ {
+ JsonBuilder.Append(",");
+ }
+ First = false;
+
+ // Emit as "provider_id": { ... }
+ JsonBuilder.Append("\"");
+ JsonBuilder.Append(Id);
+ JsonBuilder.Append("\":");
+
+ ExtendableStringBuilder<2048> StatsJson;
+ Stats.ToJson(StatsJson);
+ JsonBuilder.Append(StatsJson.ToView());
+ }
+ }
+
+ JsonBuilder.Append("}");
+
+ std::string_view Json = JsonBuilder.ToView();
+ for (auto& Conn : Connections)
+ {
+ if (Conn->IsOpen())
+ {
+ Conn->SendText(Json);
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Thread-based push (fallback when no io_context)
+//
+
+void
+HttpStatsService::PushThreadFunction()
+{
+ SetCurrentThreadName("stats_ws_push");
+
+ while (m_PushEnabled.load())
+ {
+ m_PushEvent.Wait(5000);
+ m_PushEvent.Reset();
+
+ if (!m_PushEnabled.load())
+ {
+ break;
+ }
+
+ BroadcastStats();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Timer-based push (when io_context is provided)
+//
+
+void
+HttpStatsService::EnqueuePushTimer()
+{
+ if (!m_PushTimer)
+ {
+ return;
+ }
+
+ m_PushTimer->expires_after(std::chrono::seconds(5));
+ m_PushTimer->async_wait([this](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ return;
+ }
+
+ BroadcastStats();
+ EnqueuePushTimer();
+ });
+}
+
} // namespace zen
diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp
index 708238224..cbfe4d889 100644
--- a/src/zenhttp/packageformat.cpp
+++ b/src/zenhttp/packageformat.cpp
@@ -581,7 +581,7 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint
ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize);
AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView());
- Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy});
+ Attachments.emplace_back(CbAttachment(SharedBuffer{AttachmentBufferCopy}, Entry.AttachmentHash));
}
else
{
@@ -805,6 +805,8 @@ CbPackageReader::Finalize()
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("http.packageformat");
+
TEST_CASE("CbPackage.Serialization")
{
// Make a test package
@@ -926,6 +928,8 @@ TEST_CASE("CbPackage.LocalRef")
Reader.Finalize();
}
+TEST_SUITE_END();
+
void
forcelink_packageformat()
{
diff --git a/src/zenhttp/security/passwordsecurity.cpp b/src/zenhttp/security/passwordsecurity.cpp
new file mode 100644
index 000000000..0e3a743c3
--- /dev/null
+++ b/src/zenhttp/security/passwordsecurity.cpp
@@ -0,0 +1,176 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenhttp/security/passwordsecurity.h"
+#include <zencore/compactbinaryutil.h>
+#include <zencore/fmtutils.h>
+#include <zencore/string.h>
+
+#if ZEN_WITH_TESTS
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/testing.h>
+#endif // ZEN_WITH_TESTS
+
+namespace zen {
+using namespace std::literals;
+
+PasswordSecurity::PasswordSecurity(const Configuration& Config) : m_Config(Config)
+{
+ m_UnprotectedUriHashes.reserve(m_Config.UnprotectedUris.size());
+ for (uint32_t Index = 0; Index < m_Config.UnprotectedUris.size(); Index++)
+ {
+ const std::string& UnprotectedUri = m_Config.UnprotectedUris[Index];
+ if (auto Result = m_UnprotectedUriHashes.insert({HashStringDjb2(UnprotectedUri), Index}); !Result.second)
+ {
+ throw std::runtime_error(fmt::format(
+ "password security unprotected uris does not generate unique hashes. Uri #{} ('{}') collides with uri #{} ('{}')",
+ Index + 1,
+ UnprotectedUri,
+ Result.first->second + 1,
+ m_Config.UnprotectedUris[Result.first->second]));
+ }
+ }
+}
+
+bool
+PasswordSecurity::IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const
+{
+ if (!m_Config.UnprotectedUris.empty())
+ {
+ uint32_t UriHash = HashStringDjb2(std::array<const std::string_view, 2>{BaseUri, RelativeUri});
+ if (auto It = m_UnprotectedUriHashes.find(UriHash); It != m_UnprotectedUriHashes.end())
+ {
+ const std::string_view& UnprotectedUri = m_Config.UnprotectedUris[It->second];
+ if (UnprotectedUri.length() == BaseUri.length() + RelativeUri.length())
+ {
+ if (UnprotectedUri.substr(0, BaseUri.length()) == BaseUri && UnprotectedUri.substr(BaseUri.length()) == RelativeUri)
+ {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
+bool
+PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest)
+{
+ if (IsUnprotectedUri(BaseUri, RelativeUri))
+ {
+ return true;
+ }
+ if (!ProtectMachineLocalRequests() && IsMachineLocalRequest)
+ {
+ return true;
+ }
+ if (Password().empty())
+ {
+ return true;
+ }
+ if (Password() == InPassword)
+ {
+ return true;
+ }
+ return false;
+}
+
+#if ZEN_WITH_TESTS
+
+TEST_SUITE_BEGIN("http.passwordsecurity");
+
+TEST_CASE("passwordsecurity.allowanything")
+{
+ PasswordSecurity Anything({});
+ CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+}
+
+TEST_CASE("passwordsecurity.allowalllocal")
+{
+ PasswordSecurity AllLocal({.Password = "123456"});
+ CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+}
+
+TEST_CASE("passwordsecurity.allowonlypassword")
+{
+ PasswordSecurity AllLocal({.Password = "123456", .ProtectMachineLocalRequests = true});
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+}
+
+TEST_CASE("passwordsecurity.allowsomeexternaluris")
+{
+ PasswordSecurity AllLocal(
+ {.Password = "123456", .ProtectMachineLocalRequests = false, .UnprotectedUris = std::vector<std::string>({"/free/access", "/ok"})});
+ CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+}
+
+TEST_CASE("passwordsecurity.allowsomelocaluris")
+{
+ PasswordSecurity AllLocal(
+ {.Password = "123456", .ProtectMachineLocalRequests = true, .UnprotectedUris = std::vector<std::string>({"/free/access", "/ok"})});
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+}
+
+TEST_CASE("passwordsecurity.conflictingunprotecteduris")
+{
+ try
+ {
+ PasswordSecurity AllLocal({.Password = "123456",
+ .ProtectMachineLocalRequests = true,
+ .UnprotectedUris = std::vector<std::string>({"/free/access", "/free/access"})});
+ CHECK(false);
+ }
+ catch (const std::runtime_error& Ex)
+ {
+ CHECK_EQ(Ex.what(),
+ std::string("password security unprotected uris does not generate unique hashes. Uri #2 ('/free/access') collides with "
+ "uri #1 ('/free/access')"));
+ }
+}
+
+TEST_SUITE_END();
+
+void
+passwordsecurity_forcelink()
+{
+}
+#endif // ZEN_WITH_TESTS
+
+} // namespace zen
diff --git a/src/zenhttp/security/passwordsecurityfilter.cpp b/src/zenhttp/security/passwordsecurityfilter.cpp
new file mode 100644
index 000000000..87d8cc275
--- /dev/null
+++ b/src/zenhttp/security/passwordsecurityfilter.cpp
@@ -0,0 +1,56 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenhttp/security/passwordsecurityfilter.h"
+
+#include <zencore/base64.h>
+#include <zencore/compactbinaryutil.h>
+#include <zencore/fmtutils.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+PasswordHttpFilter::Configuration
+PasswordHttpFilter::ReadConfiguration(CbObjectView Config)
+{
+ Configuration Result;
+ if (CbObjectView PasswordType = Config["basic"sv].AsObjectView(); PasswordType)
+ {
+ Result.AuthenticationTypeString = "Basic ";
+ std::string_view Username = PasswordType["username"sv].AsString();
+ std::string_view Password = PasswordType["password"sv].AsString();
+ std::string UsernamePassword = fmt::format("{}:{}", Username, Password);
+ Result.PasswordConfig.Password.resize(Base64::GetEncodedDataSize(uint32_t(UsernamePassword.length())));
+ Base64::Encode(reinterpret_cast<const uint8_t*>(UsernamePassword.data()),
+ uint32_t(UsernamePassword.size()),
+ const_cast<char*>(Result.PasswordConfig.Password.data()));
+ }
+ Result.PasswordConfig.ProtectMachineLocalRequests = Config["protect-machine-local-requests"sv].AsBool();
+ Result.PasswordConfig.UnprotectedUris = compactbinary_helpers::ReadArray<std::string>("unprotected-uris"sv, Config);
+ return Result;
+}
+
+IHttpRequestFilter::Result
+PasswordHttpFilter::FilterRequest(HttpServerRequest& Request)
+{
+ std::string_view Password;
+ std::string_view AuthorizationHeader = Request.GetAuthorizationHeader();
+ size_t AuthorizationHeaderLength = AuthorizationHeader.length();
+ if (AuthorizationHeaderLength > m_AuthenticationTypeString.length())
+ {
+ if (StrCaseCompare(AuthorizationHeader.data(), m_AuthenticationTypeString.c_str(), m_AuthenticationTypeString.length()) == 0)
+ {
+ Password = AuthorizationHeader.substr(m_AuthenticationTypeString.length());
+ }
+ }
+
+ bool IsAllowed =
+ m_PasswordSecurity.IsAllowed(Password, Request.Service().BaseUri(), Request.RelativeUri(), Request.IsLocalMachineRequest());
+ if (IsAllowed)
+ {
+ return Result::Accepted;
+ }
+ return Result::Forbidden;
+}
+
+} // namespace zen
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index 18a0f6a40..f5178ebe8 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -7,12 +7,15 @@
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/memory/llm.h>
+#include <zencore/system.h>
#include <zencore/thread.h>
#include <zencore/trace.h>
#include <zencore/windows.h>
#include <zenhttp/httpserver.h>
#include "httpparser.h"
+#include "wsasio.h"
+#include "wsframecodec.h"
#include <EASTL/fixed_vector.h>
@@ -89,15 +92,19 @@ IsIPv6AvailableSysctl(void)
char buf[16];
if (fgets(buf, sizeof(buf), f))
{
- fclose(f);
// 0 means IPv6 enabled, 1 means disabled
val = atoi(buf);
}
+ fclose(f);
}
return val == 0;
}
+#endif // ZEN_PLATFORM_LINUX
+namespace zen {
+
+#if ZEN_PLATFORM_LINUX
bool
IsIPv6Capable()
{
@@ -121,8 +128,6 @@ IsIPv6Capable()
}
#endif
-namespace zen {
-
const FLLMTag&
GetHttpasioTag()
{
@@ -145,7 +150,7 @@ inline LoggerRef
InitLogger()
{
LoggerRef Logger = logging::Get("asio");
- // Logger.set_level(spdlog::level::trace);
+ // Logger.SetLogLevel(logging::Trace);
return Logger;
}
@@ -496,16 +501,21 @@ public:
HttpAsioServerImpl();
~HttpAsioServerImpl();
- void Initialize(std::filesystem::path DataDir);
- int Start(uint16_t Port, const AsioConfig& Config);
- void Stop();
- void RegisterService(const char* UrlPath, HttpService& Service);
- HttpService* RouteRequest(std::string_view Url);
+ void Initialize(std::filesystem::path DataDir);
+ int Start(uint16_t Port, const AsioConfig& Config);
+ void Stop();
+ void RegisterService(const char* UrlPath, HttpService& Service);
+ void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter);
+ HttpService* RouteRequest(std::string_view Url);
+ IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request);
+
+ bool IsLoopbackOnly() const;
asio::io_service m_IoService;
asio::io_service::work m_Work{m_IoService};
std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
std::vector<std::thread> m_ThreadPool;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
LoggerRef m_RequestLog;
HttpServerTracer m_RequestTracer;
@@ -518,6 +528,11 @@ public:
RwLock m_Lock;
std::vector<ServiceEntry> m_UriHandlers;
+
+ std::atomic<uint64_t> m_TotalBytesReceived{0};
+ std::atomic<uint64_t> m_TotalBytesSent{0};
+
+ HttpServer* m_HttpServer = nullptr;
};
/**
@@ -527,12 +542,21 @@ public:
class HttpAsioServerRequest : public HttpServerRequest
{
public:
- HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer, uint32_t RequestNumber);
+ HttpAsioServerRequest(HttpRequestParser& Request,
+ HttpService& Service,
+ IoBuffer PayloadBuffer,
+ uint32_t RequestNumber,
+ bool IsLocalMachineRequest,
+ std::string RemoteAddress);
~HttpAsioServerRequest();
virtual Oid ParseSessionId() const override;
virtual uint32_t ParseRequestId() const override;
+ virtual bool IsLocalMachineRequest() const override;
+ virtual std::string_view GetAuthorizationHeader() const override;
+ virtual std::string_view GetRemoteAddress() const override;
+
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
@@ -548,6 +572,8 @@ public:
HttpRequestParser& m_Request;
uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers
IoBuffer m_PayloadBuffer;
+ bool m_IsLocalMachineRequest;
+ std::string m_RemoteAddress;
std::unique_ptr<HttpResponse> m_Response;
};
@@ -925,6 +951,7 @@ private:
void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, uint32_t RequestNumber, HttpResponse* ResponseToPop);
void CloseConnection();
+ void SendInlineResponse(uint32_t RequestNumber, std::string_view StatusLine, std::string_view Headers = {}, std::string_view Body = {});
HttpAsioServerImpl& m_Server;
asio::streambuf m_RequestBuffer;
@@ -1025,6 +1052,8 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]
}
}
+ m_Server.m_TotalBytesReceived.fetch_add(ByteCount, std::memory_order_relaxed);
+
ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}",
m_ConnectionId,
m_RequestCounter.load(std::memory_order_relaxed),
@@ -1078,6 +1107,8 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec,
return;
}
+ m_Server.m_TotalBytesSent.fetch_add(ByteCount, std::memory_order_relaxed);
+
ZEN_TRACE_VERBOSE("on data sent, connection: {}, request: {}, thread: {}, bytes: {}",
m_ConnectionId,
RequestNumber,
@@ -1139,10 +1170,91 @@ HttpServerConnection::CloseConnection()
}
void
+HttpServerConnection::SendInlineResponse(uint32_t RequestNumber,
+ std::string_view StatusLine,
+ std::string_view Headers,
+ std::string_view Body)
+{
+ ExtendableStringBuilder<256> ResponseBuilder;
+ ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n";
+ if (!Headers.empty())
+ {
+ ResponseBuilder << Headers;
+ }
+ if (!m_RequestData.IsKeepAlive())
+ {
+ ResponseBuilder << "Connection: close\r\n";
+ }
+ ResponseBuilder << "\r\n";
+ if (!Body.empty())
+ {
+ ResponseBuilder << Body;
+ }
+ auto ResponseView = ResponseBuilder.ToView();
+ IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size());
+ auto Buffer = asio::buffer(ResponseData.GetData(), ResponseData.GetSize());
+ asio::async_write(
+ *m_Socket.get(),
+ Buffer,
+ [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) {
+ Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
+ });
+}
+
+void
HttpServerConnection::HandleRequest()
{
ZEN_MEMSCOPE(GetHttpasioTag());
+ // WebSocket upgrade detection must happen before the keep-alive check below,
+ // because Upgrade requests have "Connection: Upgrade" which the HTTP parser
+ // treats as non-keep-alive, causing a premature shutdown of the receive side.
+ if (m_RequestData.IsWebSocketUpgrade())
+ {
+ if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url()))
+ {
+ IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service);
+ if (WsHandler && !m_RequestData.SecWebSocketKey().empty())
+ {
+ std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(m_RequestData.SecWebSocketKey());
+
+ auto ResponseStr = std::make_shared<std::string>();
+ ResponseStr->reserve(256);
+ ResponseStr->append(
+ "HTTP/1.1 101 Switching Protocols\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: ");
+ ResponseStr->append(AcceptKey);
+ ResponseStr->append("\r\n\r\n");
+
+ // Send the 101 response on the current socket, then hand the socket off
+ // to a WsAsioConnection for the WebSocket protocol.
+ asio::async_write(*m_Socket,
+ asio::buffer(ResponseStr->data(), ResponseStr->size()),
+ [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
+ return;
+ }
+
+ Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened();
+ Ref<WsAsioConnection> WsConn(
+ new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+ });
+
+ m_RequestState = RequestState::kDone;
+ return;
+ }
+ }
+ // Service doesn't support WebSocket or missing key — fall through to normal handling
+ }
+
if (!m_RequestData.IsKeepAlive())
{
m_RequestState = RequestState::kWritingFinal;
@@ -1166,14 +1278,24 @@ HttpServerConnection::HandleRequest()
{
ZEN_TRACE_CPU("asio::HandleRequest");
- HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body(), RequestNumber);
+ m_Server.m_HttpServer->MarkRequest();
+
+ auto RemoteEndpoint = m_Socket->remote_endpoint();
+ bool IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address();
+
+ HttpAsioServerRequest Request(m_RequestData,
+ *Service,
+ m_RequestData.Body(),
+ RequestNumber,
+ IsLocalConnection,
+ RemoteEndpoint.address().to_string());
ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber);
const HttpVerb RequestVerb = Request.RequestVerb();
const std::string_view Uri = Request.RelativeUri();
- if (m_Server.m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server.m_RequestLog.ShouldLog(logging::Trace))
{
ZEN_LOG_TRACE(m_Server.m_RequestLog,
"connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})",
@@ -1188,56 +1310,73 @@ HttpServerConnection::HandleRequest()
std::vector<IoBuffer>{Request.ReadPayload()});
}
- if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_Server.FilterRequest(Request);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
{
- try
- {
- Service->HandleRequest(Request);
- }
- catch (const AssertException& AssertEx)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
- }
- catch (const std::system_error& SystemError)
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
{
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ try
{
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ Service->HandleRequest(Request);
}
- else
+ catch (const AssertException& AssertEx)
{
- ZEN_WARN("Caught system error exception while handling request: {}. ({})",
- SystemError.what(),
- SystemError.code().value());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
}
- }
- catch (const std::bad_alloc& BadAlloc)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ catch (const std::system_error& SystemError)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
- }
- catch (const std::exception& ex)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ {
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ }
+ else
+ {
+ ZEN_WARN("Caught system error exception while handling request: {}. ({})",
+ SystemError.what(),
+ SystemError.code().value());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ }
+ }
+ catch (const std::bad_alloc& BadAlloc)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- ZEN_WARN("Caught exception while handling request: {}", ex.what());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
+ }
+ catch (const std::exception& ex)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_WARN("Caught exception while handling request: {}", ex.what());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
}
}
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
+ {
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
+ }
if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response))
{
+ if (Request.ShouldLogRequest())
+ {
+ ZEN_INFO("{} {} {} -> {}", ToString(RequestVerb), Uri, Response->ResponseCode(), NiceBytes(Response->ContentLength()));
+ }
+
// Transmit the response
if (m_RequestData.RequestVerb() == HttpVerb::kHead)
@@ -1278,51 +1417,24 @@ HttpServerConnection::HandleRequest()
}
}
- if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ // If a default redirect is configured and the request is for the root path, send a 302
+ std::string_view DefaultRedirect = m_Server.m_HttpServer->GetDefaultRedirect();
+ if (!DefaultRedirect.empty() && (m_RequestData.Url() == "/" || m_RequestData.Url().empty()))
{
- std::string_view Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "\r\n"sv;
-
- if (!m_RequestData.IsKeepAlive())
- {
- Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Connection: close\r\n"
- "\r\n"sv;
- }
-
- asio::async_write(*m_Socket.get(),
- asio::buffer(Response),
- [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) {
- Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
- });
+ ExtendableStringBuilder<128> Headers;
+ Headers << "Location: " << DefaultRedirect << "\r\nContent-Length: 0\r\n";
+ SendInlineResponse(RequestNumber, "302 Found"sv, Headers.ToView());
+ }
+ else if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ {
+ SendInlineResponse(RequestNumber, "404 NOT FOUND"sv);
}
else
{
- std::string_view Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Content-Length: 23\r\n"
- "Content-Type: text/plain\r\n"
- "\r\n"
- "No suitable route found"sv;
-
- if (!m_RequestData.IsKeepAlive())
- {
- Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Content-Length: 23\r\n"
- "Content-Type: text/plain\r\n"
- "Connection: close\r\n"
- "\r\n"
- "No suitable route found"sv;
- }
-
- asio::async_write(*m_Socket.get(),
- asio::buffer(Response),
- [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) {
- Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
- });
+ SendInlineResponse(RequestNumber,
+ "404 NOT FOUND"sv,
+ "Content-Length: 23\r\nContent-Type: text/plain\r\n"sv,
+ "No suitable route found"sv);
}
}
@@ -1348,8 +1460,11 @@ struct HttpAcceptor
m_Acceptor.set_option(exclusive_address(true));
m_AlternateProtocolAcceptor.set_option(exclusive_address(true));
#else // ZEN_PLATFORM_WINDOWS
- m_Acceptor.set_option(asio::socket_base::reuse_address(false));
- m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(false));
+ // Allow binding to a port in TIME_WAIT so the server can restart immediately
+ // after a previous instance exits. On Linux this does not allow two processes
+ // to actively listen on the same port simultaneously.
+ m_Acceptor.set_option(asio::socket_base::reuse_address(true));
+ m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(true));
#endif // ZEN_PLATFORM_WINDOWS
m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
@@ -1512,7 +1627,7 @@ struct HttpAcceptor
{
ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message());
- return 0;
+ return {};
}
if (EffectivePort != BasePort)
@@ -1569,7 +1684,8 @@ struct HttpAcceptor
void StopAccepting() { m_IsStopped = true; }
- int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); }
+ int GetAcceptPort() const { return m_Acceptor.local_endpoint().port(); }
+ bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); }
bool IsValid() const { return m_IsValid; }
@@ -1632,11 +1748,15 @@ private:
HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request,
HttpService& Service,
IoBuffer PayloadBuffer,
- uint32_t RequestNumber)
+ uint32_t RequestNumber,
+ bool IsLocalMachineRequest,
+ std::string RemoteAddress)
: HttpServerRequest(Service)
, m_Request(Request)
, m_RequestNumber(RequestNumber)
, m_PayloadBuffer(std::move(PayloadBuffer))
+, m_IsLocalMachineRequest(IsLocalMachineRequest)
+, m_RemoteAddress(std::move(RemoteAddress))
{
const int PrefixLength = Service.UriPrefixLength();
@@ -1708,6 +1828,24 @@ HttpAsioServerRequest::ParseRequestId() const
return m_Request.RequestId();
}
+bool
+HttpAsioServerRequest::IsLocalMachineRequest() const
+{
+ return m_IsLocalMachineRequest;
+}
+
+std::string_view
+HttpAsioServerRequest::GetRemoteAddress() const
+{
+ return m_RemoteAddress;
+}
+
+std::string_view
+HttpAsioServerRequest::GetAuthorizationHeader() const
+{
+ return m_Request.AuthorizationHeader();
+}
+
IoBuffer
HttpAsioServerRequest::ReadPayload()
{
@@ -1904,6 +2042,37 @@ HttpAsioServerImpl::RouteRequest(std::string_view Url)
return CandidateService;
}
+void
+HttpAsioServerImpl::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+IHttpRequestFilter::Result
+HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_Lock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+
+ return RequestFilter->FilterRequest(Request);
+}
+
+bool
+HttpAsioServerImpl::IsLoopbackOnly() const
+{
+ return m_Acceptor && m_Acceptor->IsLoopbackOnly();
+}
+
} // namespace zen::asio_http
//////////////////////////////////////////////////////////////////////////
@@ -1916,11 +2085,15 @@ public:
HttpAsioServer(const AsioConfig& Config);
~HttpAsioServer();
- virtual void OnRegisterService(HttpService& Service) override;
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool IsInteractiveSession) override;
- virtual void OnRequestExit() override;
- virtual void OnClose() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual void OnRun(bool IsInteractiveSession) override;
+ virtual void OnRequestExit() override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
+ virtual uint64_t GetTotalBytesReceived() const override;
+ virtual uint64_t GetTotalBytesSent() const override;
private:
Event m_ShutdownEvent;
@@ -1934,6 +2107,7 @@ HttpAsioServer::HttpAsioServer(const AsioConfig& Config)
: m_InitialConfig(Config)
, m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>())
{
+ m_Impl->m_HttpServer = this;
ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(HttpRequestParser), sizeof(HttpRequestParser));
}
@@ -1965,6 +2139,12 @@ HttpAsioServer::OnRegisterService(HttpService& Service)
m_Impl->RegisterService(Service.BaseUri(), Service);
}
+void
+HttpAsioServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ m_Impl->SetHttpRequestFilter(RequestFilter);
+}
+
int
HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
{
@@ -1989,10 +2169,46 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
return m_BasePort;
}
+std::string
+HttpAsioServer::OnGetExternalHost() const
+{
+ if (m_Impl->IsLoopbackOnly())
+ {
+ return "127.0.0.1";
+ }
+
+ // Use the UDP connect trick: connecting a UDP socket to an external address
+ // causes the OS to select the appropriate local interface without sending any data.
+ try
+ {
+ asio::io_service IoService;
+ asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4());
+ Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80));
+ return Sock.local_endpoint().address().to_string();
+ }
+ catch (const std::exception&)
+ {
+ return GetMachineName();
+ }
+}
+
+uint64_t
+HttpAsioServer::GetTotalBytesReceived() const
+{
+ return m_Impl->m_TotalBytesReceived.load(std::memory_order_relaxed);
+}
+
+uint64_t
+HttpAsioServer::GetTotalBytesSent() const
+{
+ return m_Impl->m_TotalBytesSent.load(std::memory_order_relaxed);
+}
+
void
HttpAsioServer::OnRun(bool IsInteractive)
{
- const int WaitTimeout = 1000;
+ const int WaitTimeout = 1000;
+ bool ShutdownRequested = false;
#if ZEN_PLATFORM_WINDOWS
if (IsInteractive)
@@ -2008,12 +2224,13 @@ HttpAsioServer::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#else
if (IsInteractive)
{
@@ -2022,8 +2239,8 @@ HttpAsioServer::OnRun(bool IsInteractive)
do
{
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#endif
}
diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h
index c483dfc28..3ec1141a7 100644
--- a/src/zenhttp/servers/httpasio.h
+++ b/src/zenhttp/servers/httpasio.h
@@ -15,4 +15,6 @@ struct AsioConfig
Ref<HttpServer> CreateHttpAsioServer(const AsioConfig& Config);
+bool IsIPv6Capable();
+
} // namespace zen
diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp
index 31cb04be5..584e06cbf 100644
--- a/src/zenhttp/servers/httpmulti.cpp
+++ b/src/zenhttp/servers/httpmulti.cpp
@@ -54,9 +54,19 @@ HttpMultiServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
}
void
+HttpMultiServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ for (auto& Server : m_Servers)
+ {
+ Server->SetHttpRequestFilter(RequestFilter);
+ }
+}
+
+void
HttpMultiServer::OnRun(bool IsInteractiveSession)
{
- const int WaitTimeout = 1000;
+ const int WaitTimeout = 1000;
+ bool ShutdownRequested = false;
#if ZEN_PLATFORM_WINDOWS
if (IsInteractiveSession)
@@ -72,12 +82,13 @@ HttpMultiServer::OnRun(bool IsInteractiveSession)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#else
if (IsInteractiveSession)
{
@@ -86,8 +97,8 @@ HttpMultiServer::OnRun(bool IsInteractiveSession)
do
{
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#endif
}
@@ -106,6 +117,16 @@ HttpMultiServer::OnClose()
}
}
+std::string
+HttpMultiServer::OnGetExternalHost() const
+{
+ if (!m_Servers.empty())
+ {
+ return std::string(m_Servers.front()->GetExternalHost());
+ }
+ return HttpServer::OnGetExternalHost();
+}
+
void
HttpMultiServer::AddServer(Ref<HttpServer> Server)
{
diff --git a/src/zenhttp/servers/httpmulti.h b/src/zenhttp/servers/httpmulti.h
index ae0ed74cf..97699828a 100644
--- a/src/zenhttp/servers/httpmulti.h
+++ b/src/zenhttp/servers/httpmulti.h
@@ -15,11 +15,13 @@ public:
HttpMultiServer();
~HttpMultiServer();
- virtual void OnRegisterService(HttpService& Service) override;
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool IsInteractiveSession) override;
- virtual void OnRequestExit() override;
- virtual void OnClose() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnRun(bool IsInteractiveSession) override;
+ virtual void OnRequestExit() override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
void AddServer(Ref<HttpServer> Server);
diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp
index 0ec1cb3c4..9bb7ef3bc 100644
--- a/src/zenhttp/servers/httpnull.cpp
+++ b/src/zenhttp/servers/httpnull.cpp
@@ -24,6 +24,12 @@ HttpNullServer::OnRegisterService(HttpService& Service)
ZEN_UNUSED(Service);
}
+void
+HttpNullServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ ZEN_UNUSED(RequestFilter);
+}
+
int
HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
{
@@ -34,7 +40,8 @@ HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
void
HttpNullServer::OnRun(bool IsInteractiveSession)
{
- const int WaitTimeout = 1000;
+ const int WaitTimeout = 1000;
+ bool ShutdownRequested = false;
#if ZEN_PLATFORM_WINDOWS
if (IsInteractiveSession)
@@ -50,12 +57,13 @@ HttpNullServer::OnRun(bool IsInteractiveSession)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#else
if (IsInteractiveSession)
{
@@ -64,8 +72,8 @@ HttpNullServer::OnRun(bool IsInteractiveSession)
do
{
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#endif
}
diff --git a/src/zenhttp/servers/httpnull.h b/src/zenhttp/servers/httpnull.h
index ce7230938..52838f012 100644
--- a/src/zenhttp/servers/httpnull.h
+++ b/src/zenhttp/servers/httpnull.h
@@ -18,6 +18,7 @@ public:
~HttpNullServer();
virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
virtual void OnRun(bool IsInteractiveSession) override;
virtual void OnRequestExit() override;
diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp
index 93094e21b..918b55dc6 100644
--- a/src/zenhttp/servers/httpparser.cpp
+++ b/src/zenhttp/servers/httpparser.cpp
@@ -12,13 +12,17 @@ namespace zen {
using namespace std::literals;
-static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
-static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
-static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
-static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
-static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
-static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
-static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
+static constexpr uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
+static constexpr uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
+static constexpr uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
+static constexpr uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
+static constexpr uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
+static constexpr uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
+static constexpr uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
+static constexpr uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv);
+static constexpr uint32_t HashUpgrade = HashStringAsLowerDjb2("Upgrade"sv);
+static constexpr uint32_t HashSecWebSocketKey = HashStringAsLowerDjb2("Sec-WebSocket-Key"sv);
+static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-WebSocket-Version"sv);
//////////////////////////////////////////////////////////////////////////
//
@@ -142,41 +146,62 @@ HttpRequestParser::ParseCurrentHeader()
const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName);
const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1);
- if (HeaderHash == HashContentLength)
+ switch (HeaderHash)
{
- m_ContentLengthHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashAccept)
- {
- m_AcceptHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashContentType)
- {
- m_ContentTypeHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashSession)
- {
- m_SessionId = Oid::TryFromHexString(HeaderValue);
- }
- else if (HeaderHash == HashRequest)
- {
- std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
- }
- else if (HeaderHash == HashExpect)
- {
- if (HeaderValue == "100-continue"sv)
- {
- // We don't currently do anything with this
- m_Expect100Continue = true;
- }
- else
- {
- ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
- }
- }
- else if (HeaderHash == HashRange)
- {
- m_RangeHeaderIndex = CurrentHeaderIndex;
+ case HashContentLength:
+ m_ContentLengthHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashAccept:
+ m_AcceptHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashContentType:
+ m_ContentTypeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashAuthorization:
+ m_AuthorizationHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSession:
+ m_SessionId = Oid::TryFromHexString(HeaderValue);
+ break;
+
+ case HashRequest:
+ std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
+ break;
+
+ case HashExpect:
+ if (HeaderValue == "100-continue"sv)
+ {
+ // We don't currently do anything with this
+ m_Expect100Continue = true;
+ }
+ else
+ {
+ ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
+ }
+ break;
+
+ case HashRange:
+ m_RangeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashUpgrade:
+ m_UpgradeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSecWebSocketKey:
+ m_SecWebSocketKeyHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSecWebSocketVersion:
+ m_SecWebSocketVersionHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ default:
+ break;
}
}
@@ -220,11 +245,6 @@ NormalizeUrlPath(std::string_view InUrl, std::string& NormalizedUrl)
NormalizedUrl.reserve(UrlLength);
NormalizedUrl.append(Url, UrlIndex);
}
-
- if (!LastCharWasSeparator)
- {
- NormalizedUrl.push_back('/');
- }
}
else if (!NormalizedUrl.empty())
{
@@ -305,6 +325,7 @@ HttpRequestParser::OnHeadersComplete()
if (ContentLength)
{
+ // TODO: should sanity-check content length here
m_BodyBuffer = IoBuffer(ContentLength);
}
@@ -324,9 +345,9 @@ HttpRequestParser::OnHeadersComplete()
int
HttpRequestParser::OnBody(const char* Data, size_t Bytes)
{
- if (m_BodyPosition + Bytes > m_BodyBuffer.Size())
+ if ((m_BodyPosition + Bytes) > m_BodyBuffer.Size())
{
- ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes",
+ ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes",
(m_BodyPosition + Bytes) - m_BodyBuffer.Size());
return 1;
}
@@ -337,7 +358,7 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes)
{
if (m_BodyPosition != m_BodyBuffer.Size())
{
- ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
+ ZEN_WARN("Body size mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
return 1;
}
}
@@ -353,13 +374,18 @@ HttpRequestParser::ResetState()
m_HeaderEntries.clear();
- m_ContentLengthHeaderIndex = -1;
- m_AcceptHeaderIndex = -1;
- m_ContentTypeHeaderIndex = -1;
- m_RangeHeaderIndex = -1;
- m_Expect100Continue = false;
- m_BodyBuffer = {};
- m_BodyPosition = 0;
+ m_ContentLengthHeaderIndex = -1;
+ m_AcceptHeaderIndex = -1;
+ m_ContentTypeHeaderIndex = -1;
+ m_RangeHeaderIndex = -1;
+ m_AuthorizationHeaderIndex = -1;
+ m_UpgradeHeaderIndex = -1;
+ m_SecWebSocketKeyHeaderIndex = -1;
+ m_SecWebSocketVersionHeaderIndex = -1;
+ m_RequestVerb = HttpVerb::kGet;
+ m_Expect100Continue = false;
+ m_BodyBuffer = {};
+ m_BodyPosition = 0;
m_HeaderData.clear();
m_NormalizedUrl.clear();
@@ -416,4 +442,21 @@ HttpRequestParser::OnMessageComplete()
}
}
+bool
+HttpRequestParser::IsWebSocketUpgrade() const
+{
+ std::string_view Upgrade = GetHeaderValue(m_UpgradeHeaderIndex);
+ if (Upgrade.empty())
+ {
+ return false;
+ }
+
+ // Case-insensitive check for "websocket"
+ if (Upgrade.size() != 9)
+ {
+ return false;
+ }
+ return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0;
+}
+
} // namespace zen
diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h
index 0d2664ec5..23ad9d8fb 100644
--- a/src/zenhttp/servers/httpparser.h
+++ b/src/zenhttp/servers/httpparser.h
@@ -46,6 +46,12 @@ struct HttpRequestParser
std::string_view RangeHeader() const { return GetHeaderValue(m_RangeHeaderIndex); }
+ std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); }
+
+ std::string_view UpgradeHeader() const { return GetHeaderValue(m_UpgradeHeaderIndex); }
+ std::string_view SecWebSocketKey() const { return GetHeaderValue(m_SecWebSocketKeyHeaderIndex); }
+ bool IsWebSocketUpgrade() const;
+
private:
struct HeaderRange
{
@@ -83,7 +89,11 @@ private:
int8_t m_AcceptHeaderIndex;
int8_t m_ContentTypeHeaderIndex;
int8_t m_RangeHeaderIndex;
- HttpVerb m_RequestVerb;
+ int8_t m_AuthorizationHeaderIndex;
+ int8_t m_UpgradeHeaderIndex;
+ int8_t m_SecWebSocketKeyHeaderIndex;
+ int8_t m_SecWebSocketVersionHeaderIndex;
+ HttpVerb m_RequestVerb = HttpVerb::kGet;
std::atomic_bool m_KeepAlive{false};
bool m_Expect100Continue = false;
int m_RequestId = -1;
diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp
index b9217ed87..4bf8c61bb 100644
--- a/src/zenhttp/servers/httpplugin.cpp
+++ b/src/zenhttp/servers/httpplugin.cpp
@@ -96,6 +96,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
// HttpPluginServer
virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
virtual void OnRun(bool IsInteractiveSession) override;
virtual void OnRequestExit() override;
@@ -104,7 +105,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
virtual void AddPlugin(Ref<TransportPlugin> Plugin) override;
virtual void RemovePlugin(Ref<TransportPlugin> Plugin) override;
- HttpService* RouteRequest(std::string_view Url);
+ HttpService* RouteRequest(std::string_view Url);
+ IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request);
struct ServiceEntry
{
@@ -112,7 +114,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
HttpService* Service;
};
- bool m_IsInitialized = false;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+ bool m_IsInitialized = false;
RwLock m_Lock;
std::vector<ServiceEntry> m_UriHandlers;
std::vector<Ref<TransportPlugin>> m_Plugins;
@@ -120,7 +123,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
bool m_IsRequestLoggingEnabled = false;
LoggerRef m_RequestLog;
std::atomic_uint32_t m_ConnectionIdCounter{0};
- int m_BasePort;
+ int m_BasePort = 0;
HttpServerTracer m_RequestTracer;
@@ -143,8 +146,11 @@ public:
HttpPluginServerRequest(const HttpPluginServerRequest&) = delete;
HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete;
- virtual Oid ParseSessionId() const override;
- virtual uint32_t ParseRequestId() const override;
+    // As this is a plugin transport connection used for specialized connections, we assume it is not a machine-local connection
+ virtual bool IsLocalMachineRequest() const /* override*/ { return false; }
+ virtual std::string_view GetAuthorizationHeader() const override;
+ virtual Oid ParseSessionId() const override;
+ virtual uint32_t ParseRequestId() const override;
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
@@ -288,7 +294,7 @@ HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPlug
ConnectionName = "anonymous";
}
- ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('')", m_ConnectionId, ConnectionName);
+ ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('{}')", m_ConnectionId, ConnectionName);
}
uint32_t
@@ -372,12 +378,14 @@ HttpPluginConnectionHandler::HandleRequest()
{
ZEN_TRACE_CPU("http_plugin::HandleRequest");
+ m_Server->MarkRequest();
+
HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body());
const HttpVerb RequestVerb = Request.RequestVerb();
const std::string_view Uri = Request.RelativeUri();
- if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server->m_RequestLog.ShouldLog(logging::Trace))
{
ZEN_LOG_TRACE(m_Server->m_RequestLog,
"connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})",
@@ -392,53 +400,65 @@ HttpPluginConnectionHandler::HandleRequest()
std::vector<IoBuffer>{Request.ReadPayload()});
}
- if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_Server->FilterRequest(Request);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
{
- try
- {
- Service->HandleRequest(Request);
- }
- catch (const AssertException& AssertEx)
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
{
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
- }
- catch (const std::system_error& SystemError)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ try
{
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ Service->HandleRequest(Request);
}
- else
+ catch (const AssertException& AssertEx)
{
- ZEN_WARN("Caught system error exception while handling request: {}. ({})",
- SystemError.what(),
- SystemError.code().value());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
}
- }
- catch (const std::bad_alloc& BadAlloc)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ catch (const std::system_error& SystemError)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ {
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ }
+ else
+ {
+ ZEN_WARN("Caught system error exception while handling request: {}. ({})",
+ SystemError.what(),
+ SystemError.code().value());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ }
+ }
+ catch (const std::bad_alloc& BadAlloc)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
- }
- catch (const std::exception& ex)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
+ }
+ catch (const std::exception& ex)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- ZEN_WARN("Caught exception while handling request: {}", ex.what());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ ZEN_WARN("Caught exception while handling request: {}", ex.what());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
}
}
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
+ {
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
+ }
if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response))
{
@@ -462,7 +482,7 @@ HttpPluginConnectionHandler::HandleRequest()
const std::vector<IoBuffer>& ResponseBuffers = Response->ResponseBuffers();
- if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server->m_RequestLog.ShouldLog(logging::Trace))
{
m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber),
ResponseBuffers);
@@ -618,6 +638,12 @@ HttpPluginServerRequest::~HttpPluginServerRequest()
{
}
+std::string_view
+HttpPluginServerRequest::GetAuthorizationHeader() const
+{
+ return m_Request.AuthorizationHeader();
+}
+
Oid
HttpPluginServerRequest::ParseSessionId() const
{
@@ -750,6 +776,13 @@ HttpPluginServerImpl::OnInitialize(int InBasePort, std::filesystem::path DataDir
}
void
+HttpPluginServerImpl::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+void
HttpPluginServerImpl::OnClose()
{
if (!m_IsInitialized)
@@ -806,6 +839,7 @@ HttpPluginServerImpl::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
@@ -894,6 +928,22 @@ HttpPluginServerImpl::RouteRequest(std::string_view Url)
return CandidateService;
}
+IHttpRequestFilter::Result
+HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_Lock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ return RequestFilter->FilterRequest(Request);
+}
+
//////////////////////////////////////////////////////////////////////////
struct HttpPluginServerImpl;
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 54cc0c22d..dfe6bb6aa 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -12,6 +12,7 @@
#include <zencore/memory/llm.h>
#include <zencore/scopeguard.h>
#include <zencore/string.h>
+#include <zencore/system.h>
#include <zencore/timer.h>
#include <zencore/trace.h>
#include <zenhttp/packageformat.h>
@@ -25,7 +26,9 @@
# include <zencore/workthreadpool.h>
# include "iothreadpool.h"
+# include <atomic>
# include <http.h>
+# include <asio.hpp> // for resolving addresses for GetExternalHost
namespace zen {
@@ -72,6 +75,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In
OutString.Append("unknown");
}
+class HttpSysServerRequest;
+
/**
* @brief Windows implementation of HTTP server based on http.sys
*
@@ -83,6 +88,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In
class HttpSysServer : public HttpServer
{
friend class HttpSysTransaction;
+ friend class HttpMessageResponseRequest;
+ friend struct InitialRequestHandler;
public:
explicit HttpSysServer(const HttpSysConfig& Config);
@@ -90,17 +97,23 @@ public:
// HttpServer interface implementation
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool TestMode) override;
- virtual void OnRequestExit() override;
- virtual void OnRegisterService(HttpService& Service) override;
- virtual void OnClose() override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnRun(bool TestMode) override;
+ virtual void OnRequestExit() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
+ virtual uint64_t GetTotalBytesReceived() const override;
+ virtual uint64_t GetTotalBytesSent() const override;
WorkerThreadPool& WorkPool();
inline bool IsOk() const { return m_IsOk; }
inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; }
+ IHttpRequestFilter::Result FilterRequest(HttpSysServerRequest& Request);
+
private:
int InitializeServer(int BasePort);
void Cleanup();
@@ -124,8 +137,8 @@ private:
std::unique_ptr<WinIoThreadPool> m_IoThreadPool;
- RwLock m_AsyncWorkPoolInitLock;
- WorkerThreadPool* m_AsyncWorkPool = nullptr;
+ RwLock m_AsyncWorkPoolInitLock;
+ std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr;
std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
HTTP_SERVER_SESSION_ID m_HttpSessionId = 0;
@@ -137,6 +150,12 @@ private:
int32_t m_MaxPendingRequests = 128;
Event m_ShutdownEvent;
HttpSysConfig m_InitialConfig;
+
+ RwLock m_RequestFilterLock;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+
+ std::atomic<uint64_t> m_TotalBytesReceived{0};
+ std::atomic<uint64_t> m_TotalBytesSent{0};
};
} // namespace zen
@@ -144,6 +163,10 @@ private:
#if ZEN_WITH_HTTPSYS
+# include "httpsys_iocontext.h"
+# include "wshttpsys.h"
+# include "wsframecodec.h"
+
# include <conio.h>
# include <mstcpip.h>
# pragma comment(lib, "httpapi.lib")
@@ -313,6 +336,10 @@ public:
virtual Oid ParseSessionId() const override;
virtual uint32_t ParseRequestId() const override;
+ virtual bool IsLocalMachineRequest() const override;
+ virtual std::string_view GetAuthorizationHeader() const override;
+ virtual std::string_view GetRemoteAddress() const override;
+
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
@@ -320,16 +347,19 @@ public:
virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
virtual bool TryGetRanges(HttpRanges& Ranges) override;
+ void LogRequest(HttpMessageResponseRequest* Response);
+
using HttpServerRequest::WriteResponse;
HttpSysServerRequest(const HttpSysServerRequest&) = delete;
HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete;
- HttpSysTransaction& m_HttpTx;
- HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
- IoBuffer m_PayloadBuffer;
- ExtendableStringBuilder<128> m_UriUtf8;
- ExtendableStringBuilder<128> m_QueryStringUtf8;
+ HttpSysTransaction& m_HttpTx;
+ HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
+ IoBuffer m_PayloadBuffer;
+ ExtendableStringBuilder<128> m_UriUtf8;
+ ExtendableStringBuilder<128> m_QueryStringUtf8;
+ mutable ExtendableStringBuilder<64> m_RemoteAddress;
};
/** HTTP transaction
@@ -363,7 +393,7 @@ public:
PTP_IO Iocp();
HANDLE RequestQueueHandle();
- inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; }
+ inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; }
inline HttpSysServer& Server() { return m_HttpServer; }
inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }
@@ -380,8 +410,8 @@ public:
};
private:
- OVERLAPPED m_HttpOverlapped{};
- HttpSysServer& m_HttpServer;
+ HttpSysIoContext m_IoContext{};
+ HttpSysServer& m_HttpServer;
// Tracks which handler is due to handle the next I/O completion event
HttpSysRequestHandler* m_CompletionHandler = nullptr;
@@ -418,7 +448,10 @@ public:
virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
void SuppressResponseBody(); // typically used for HEAD requests
- inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
+ inline uint16_t GetResponseCode() const { return m_ResponseCode; }
+ inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
+
+ void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; }
private:
eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks;
@@ -429,6 +462,7 @@ private:
bool m_IsInitialResponse = true;
HttpContentType m_ContentType = HttpContentType::kBinary;
eastl::fixed_vector<IoBuffer, 16> m_DataBuffers;
+ std::string m_LocationHeader;
void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs);
};
@@ -569,7 +603,7 @@ HttpMessageResponseRequest::SuppressResponseBody()
HttpSysRequestHandler*
HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
- ZEN_UNUSED(NumberOfBytesTransferred);
+ Transaction().Server().m_TotalBytesSent.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed);
if (IoResult != NO_ERROR)
{
@@ -684,6 +718,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
ContentTypeHeader->pRawValue = ContentTypeString.data();
ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
+ // Location header (for redirects)
+
+ if (!m_LocationHeader.empty())
+ {
+ PHTTP_KNOWN_HEADER LocationHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderLocation];
+ LocationHeader->pRawValue = m_LocationHeader.data();
+ LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size();
+ }
+
std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);
HttpResponse.StatusCode = m_ResponseCode;
@@ -694,21 +737,22 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
HTTP_CACHE_POLICY CachePolicy;
- CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates;
+ CachePolicy.Policy = HttpCachePolicyNocache;
CachePolicy.SecondsToLive = 0;
// Initial response API call
- SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(),
- HttpReq->RequestId,
- SendFlags,
- &HttpResponse,
- &CachePolicy,
- NULL,
- NULL,
- 0,
- Tx.Overlapped(),
- NULL);
+ SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), // RequestQueueHandle
+ HttpReq->RequestId, // RequestId
+ SendFlags, // Flags
+ &HttpResponse, // HttpResponse
+ &CachePolicy, // CachePolicy
+ NULL, // BytesSent
+ NULL, // Reserved1
+ 0, // Reserved2
+ Tx.Overlapped(), // Overlapped
+ NULL // LogData
+ );
m_IsInitialResponse = false;
}
@@ -716,9 +760,9 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
{
// Subsequent response API calls
- SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(),
- HttpReq->RequestId,
- SendFlags,
+ SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), // RequestQueueHandle
+ HttpReq->RequestId, // RequestId
+ SendFlags, // Flags
(USHORT)ThisRequestChunkCount, // EntityChunkCount
&m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks
NULL, // BytesSent
@@ -884,7 +928,10 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr
ZEN_UNUSED(IoResult, NumberOfBytesTransferred);
- ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred);
+ ZEN_WARN("Unexpected I/O completion during async work! IoResult: {} ({:#x}), NumberOfBytesTransferred: {}",
+ GetSystemErrorAsString(IoResult),
+ IoResult,
+ NumberOfBytesTransferred);
return this;
}
@@ -1017,8 +1064,10 @@ HttpSysServer::~HttpSysServer()
ZEN_ERROR("~HttpSysServer() called without calling Close() first");
}
- delete m_AsyncWorkPool;
+ auto WorkPool = m_AsyncWorkPool.load(std::memory_order_relaxed);
m_AsyncWorkPool = nullptr;
+
+ delete WorkPool;
}
void
@@ -1049,7 +1098,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})",
+ WideToUtf8(WildcardUrlPath),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1058,7 +1110,7 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result);
return 0;
}
@@ -1082,7 +1134,9 @@ HttpSysServer::InitializeServer(int BasePort)
if ((Result == ERROR_SHARING_VIOLATION))
{
- ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
+ ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying",
+ EffectivePort,
+ GetSystemErrorAsString(Result));
Sleep(500);
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
@@ -1104,7 +1158,9 @@ HttpSysServer::InitializeServer(int BasePort)
{
for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++)
{
- ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
+ ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying",
+ EffectivePort,
+ GetSystemErrorAsString(Result));
Sleep(500);
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
}
@@ -1128,25 +1184,29 @@ HttpSysServer::InitializeServer(int BasePort)
// port for the current user. eg:
// netsh http add urlacl url=http://*:8558/ user=<some_user>
- ZEN_WARN(
- "Unable to register handler using '{}' - falling back to local-only. "
- "Please ensure the appropriate netsh URL reservation configuration "
- "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)",
- WideToUtf8(WildcardUrlPath));
+ if (!m_InitialConfig.ForceLoopback)
+ {
+ ZEN_WARN(
+ "Unable to register handler using '{}' - falling back to local-only. "
+ "Please ensure the appropriate netsh URL reservation configuration "
+ "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)",
+ WideToUtf8(WildcardUrlPath));
+ }
const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
- ULONG InternalResult = ERROR_SHARING_VIOLATION;
- for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
+ bool ShouldRetryNextPort = true;
+ for (int PortOffset = 0; ShouldRetryNextPort && (PortOffset < 10); ++PortOffset)
{
- EffectivePort = BasePort + (PortOffset * 100);
+ EffectivePort = BasePort + (PortOffset * 100);
+ ShouldRetryNextPort = false;
for (const std::u8string_view Host : Hosts)
{
WideStringBuilder<64> LocalUrlPath;
LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv;
- InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
+ ULONG InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
if (InternalResult == NO_ERROR)
{
@@ -1154,11 +1214,25 @@ HttpSysServer::InitializeServer(int BasePort)
m_BaseUris.push_back(LocalUrlPath.c_str());
}
+ else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED)
+ {
+ // Port may be owned by another process's wildcard registration (access denied)
+ // or actively in use (sharing violation) — retry on a different port
+ ShouldRetryNextPort = true;
+ }
else
{
- break;
+ ZEN_WARN("Failed to register local handler '{}': {} ({:#x})",
+ WideToUtf8(LocalUrlPath),
+ GetSystemErrorAsString(InternalResult),
+ InternalResult);
}
}
+
+ if (!m_BaseUris.empty())
+ {
+ break;
+ }
}
}
else
@@ -1174,7 +1248,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (m_BaseUris.empty())
{
- ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})",
+ WideToUtf8(WildcardUrlPath),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1192,7 +1269,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+ ZEN_ERROR("Failed to create request queue for '{}': {} ({:#x})",
+ WideToUtf8(m_BaseUris.front()),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1204,7 +1284,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+ ZEN_ERROR("Failed to set server binding property for '{}': {} ({:#x})",
+ WideToUtf8(m_BaseUris.front()),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1236,7 +1319,7 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result);
+ ZEN_WARN("changing request queue length to {} failed: {} ({:#x})", QueueLength, GetSystemErrorAsString(Result), Result);
}
}
@@ -1258,21 +1341,6 @@ HttpSysServer::InitializeServer(int BasePort)
ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
}
- // This is not available in all Windows SDK versions so for now we can't use recently
- // released functionality. We should investigate how to get more recent SDK releases
- // into the build
-
-# if 0
- if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4))
- {
- ZEN_DEBUG("HTTP3 is available");
- }
- else
- {
- ZEN_DEBUG("HTTP3 is NOT available");
- }
-# endif
-
return EffectivePort;
}
@@ -1305,17 +1373,17 @@ HttpSysServer::WorkPool()
{
ZEN_MEMSCOPE(GetHttpsysTag());
- if (!m_AsyncWorkPool)
+ if (!m_AsyncWorkPool.load(std::memory_order_acquire))
{
RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock);
- if (!m_AsyncWorkPool)
+ if (!m_AsyncWorkPool.load(std::memory_order_relaxed))
{
- m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async");
+ m_AsyncWorkPool.store(new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"), std::memory_order_release);
}
}
- return *m_AsyncWorkPool;
+ return *m_AsyncWorkPool.load(std::memory_order_relaxed);
}
void
@@ -1337,9 +1405,9 @@ HttpSysServer::OnRun(bool IsInteractive)
ZEN_CONSOLE("Zen Server running (http.sys). Press ESC or Q to quit");
}
+ bool ShutdownRequested = false;
do
{
- // int WaitTimeout = -1;
int WaitTimeout = 100;
if (IsInteractive)
@@ -1352,14 +1420,15 @@ HttpSysServer::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
UpdateLofreqTimerValue();
- } while (!IsApplicationExitRequested());
+ } while (!ShutdownRequested);
}
void
@@ -1530,7 +1599,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
// than one thread at any given moment. This means we need to be careful about what
// happens in here
- HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped);
+ HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped);
+
+ switch (IoContext->ContextType)
+ {
+ case HttpSysIoContext::Type::kWebSocketRead:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kWebSocketWrite:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kTransaction:
+ break;
+ }
+
+ HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext);
if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone)
{
@@ -1641,6 +1726,8 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
{
HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload);
+ m_HttpServer.MarkRequest();
+
// Default request handling
# if ZEN_WITH_OTEL
@@ -1666,9 +1753,21 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
otel::ScopedSpan HttpSpan(SpanNamer, SpanAnnotator);
# endif
- if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_HttpServer.FilterRequest(ThisRequest);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
+ {
+ if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
+ {
+ Service.HandleRequest(ThisRequest);
+ }
+ }
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ ThisRequest.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
{
- Service.HandleRequest(ThisRequest);
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
}
return ThisRequest;
@@ -1810,6 +1909,52 @@ HttpSysServerRequest::ParseRequestId() const
return 0;
}
+bool
+HttpSysServerRequest::IsLocalMachineRequest() const
+{
+ const PSOCKADDR LocalAddress = m_HttpTx.HttpRequest()->Address.pLocalAddress;
+ const PSOCKADDR RemoteAddress = m_HttpTx.HttpRequest()->Address.pRemoteAddress;
+ if (LocalAddress->sa_family != RemoteAddress->sa_family)
+ {
+ return false;
+ }
+ if (LocalAddress->sa_family == AF_INET)
+ {
+ const SOCKADDR_IN& LocalAddressv4 = (const SOCKADDR_IN&)(*LocalAddress);
+ const SOCKADDR_IN& RemoteAddressv4 = (const SOCKADDR_IN&)(*RemoteAddress);
+ return LocalAddressv4.sin_addr.S_un.S_addr == RemoteAddressv4.sin_addr.S_un.S_addr;
+ }
+ else if (LocalAddress->sa_family == AF_INET6)
+ {
+ const SOCKADDR_IN6& LocalAddressv6 = (const SOCKADDR_IN6&)(*LocalAddress);
+ const SOCKADDR_IN6& RemoteAddressv6 = (const SOCKADDR_IN6&)(*RemoteAddress);
+ return memcmp(&LocalAddressv6.sin6_addr, &RemoteAddressv6.sin6_addr, sizeof(in6_addr)) == 0;
+ }
+ else
+ {
+ return false;
+ }
+}
+
+std::string_view
+HttpSysServerRequest::GetRemoteAddress() const
+{
+ if (m_RemoteAddress.Size() == 0)
+ {
+ const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress;
+ GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false);
+ }
+ return m_RemoteAddress.ToView();
+}
+
+std::string_view
+HttpSysServerRequest::GetAuthorizationHeader() const
+{
+ const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
+ const HTTP_KNOWN_HEADER& AuthorizationHeader = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAuthorization];
+ return std::string_view(AuthorizationHeader.pRawValue, AuthorizationHeader.RawValueLength);
+}
+
IoBuffer
HttpSysServerRequest::ReadPayload()
{
@@ -1823,7 +1968,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
ZEN_ASSERT(IsHandled() == false);
- auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
+ HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
if (SuppressBody())
{
@@ -1841,6 +1986,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
# endif
SetIsHandled();
+ LogRequest(Response);
}
void
@@ -1850,7 +1996,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
ZEN_ASSERT(IsHandled() == false);
- auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
+ HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
if (SuppressBody())
{
@@ -1868,6 +2014,20 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
# endif
SetIsHandled();
+ LogRequest(Response);
+}
+
+void
+HttpSysServerRequest::LogRequest(HttpMessageResponseRequest* Response)
+{
+ if (ShouldLogRequest())
+ {
+ ZEN_INFO("{} {} {} -> {}",
+ ToString(RequestVerb()),
+ m_UriUtf8.c_str(),
+ Response->GetResponseCode(),
+ NiceBytes(Response->GetResponseBodySize()));
+ }
}
void
@@ -1896,6 +2056,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
# endif
SetIsHandled();
+ LogRequest(Response);
}
void
@@ -2015,6 +2176,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
break;
}
+ Transaction().Server().m_TotalBytesReceived.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed);
+
ZEN_TRACE_CPU("httpsys::HandleCompletion");
// Route request
@@ -2023,64 +2186,122 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
{
HTTP_REQUEST* HttpReq = HttpRequest();
-# if 0
- for (int i = 0; i < HttpReq->RequestInfoCount; ++i)
+ if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
{
- auto& ReqInfo = HttpReq->pRequestInfo[i];
-
- switch (ReqInfo.InfoType)
+ // WebSocket upgrade detection
+ if (m_IsInitialRequest)
{
- case HttpRequestInfoTypeRequestTiming:
+ const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade];
+ if (UpgradeHeader.RawValueLength > 0 &&
+ StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0)
+ {
+ if (IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service))
{
- const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo);
+ // Extract Sec-WebSocket-Key from the unknown headers
+ // (http.sys has no known-header slot for it)
+ std::string_view SecWebSocketKey;
+ for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i)
+ {
+ const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i];
+ if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0)
+ {
+ SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength);
+ break;
+ }
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeAuth:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeChannelBind:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslProtocol:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBindingDraft:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBinding:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV0:
- {
- const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo);
+ if (SecWebSocketKey.empty())
+ {
+ ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header");
+ return nullptr;
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeRequestSizing:
- {
- const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo);
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeQuicStats:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV1:
- {
- const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo);
+ const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey);
+
+ HANDLE RequestQueueHandle = Transaction().RequestQueueHandle();
+ HTTP_REQUEST_ID RequestId = HttpReq->RequestId;
+
+ // Build the 101 Switching Protocols response
+ HTTP_RESPONSE Response = {};
+ Response.StatusCode = 101;
+ Response.pReason = "Switching Protocols";
+ Response.ReasonLength = (USHORT)strlen(Response.pReason);
+
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket";
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9;
+
+ eastl::fixed_vector<HTTP_UNKNOWN_HEADER, 8> UnknownHeaders;
- ZEN_INFO("");
+ // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders
+ // despite there being an entry for it there (HttpHeaderConnection). If you try to do
+ // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below
+
+ UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"});
+
+ UnknownHeaders.push_back({.NameLength = 20,
+ .RawValueLength = (USHORT)AcceptKey.size(),
+ .pName = "Sec-WebSocket-Accept",
+ .pRawValue = AcceptKey.c_str()});
+
+ Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size();
+ Response.Headers.pUnknownHeaders = UnknownHeaders.data();
+
+ const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
+
+ // Use an OVERLAPPED with an event so we can wait synchronously.
+ // The request queue is IOCP-associated, so passing NULL for pOverlapped
+ // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent
+ // prevents IOCP delivery and lets us wait on the event directly.
+ OVERLAPPED SendOverlapped = {};
+ HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr);
+ SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1);
+
+ ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle,
+ RequestId,
+ Flags,
+ &Response,
+ nullptr, // CachePolicy
+ nullptr, // BytesSent
+ nullptr, // Reserved1
+ 0, // Reserved2
+ &SendOverlapped,
+ nullptr // LogData
+ );
+
+ if (SendResult == ERROR_IO_PENDING)
+ {
+ WaitForSingleObject(SendEvent, INFINITE);
+ SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE;
+ }
+
+ CloseHandle(SendEvent);
+
+ if (SendResult == NO_ERROR)
+ {
+ Transaction().Server().OnWebSocketConnectionOpened();
+ Ref<WsHttpSysConnection> WsConn(new WsHttpSysConnection(RequestQueueHandle,
+ RequestId,
+ *WsHandler,
+ Transaction().Iocp(),
+ &Transaction().Server()));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+
+ return nullptr;
+ }
+
+ ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult);
+
+ // WebSocket upgrade failed — return nullptr since ServerRequest()
+ // was never populated (no InvokeRequestHandler call)
+ return nullptr;
}
- break;
+ // Service doesn't support WebSocket or missing key — fall through to normal handling
+ }
}
- }
-# endif
- if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
- {
if (m_IsInitialRequest)
{
m_ContentLength = GetContentLength(HttpReq);
@@ -2146,6 +2367,18 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv);
}
}
+ else
+ {
+ // If a default redirect is configured and the request is for the root path, send a 302
+ std::string_view DefaultRedirect = Transaction().Server().GetDefaultRedirect();
+ std::string_view RawUrl(HttpReq->pRawUrl, HttpReq->RawUrlLength);
+ if (!DefaultRedirect.empty() && (RawUrl == "/" || RawUrl.empty()))
+ {
+ auto* Response = new HttpMessageResponseRequest(Transaction(), 302);
+ Response->SetLocationHeader(DefaultRedirect);
+ return Response;
+ }
+ }
// Unable to route
return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv);
@@ -2205,12 +2438,81 @@ HttpSysServer::OnRequestExit()
m_ShutdownEvent.Set();
}
+std::string
+HttpSysServer::OnGetExternalHost() const
+{
+ // Check whether we registered a public wildcard URL (http://*:port/) or fell back to loopback
+ bool IsPublic = false;
+ for (const auto& Uri : m_BaseUris)
+ {
+ if (Uri.find(L'*') != std::wstring::npos)
+ {
+ IsPublic = true;
+ break;
+ }
+ }
+
+ if (!IsPublic)
+ {
+ return "127.0.0.1";
+ }
+
+ // Use the UDP connect trick: connecting a UDP socket to an external address
+ // causes the OS to select the appropriate local interface without sending any data.
+ try
+ {
+ asio::io_service IoService;
+ asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4());
+ Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80));
+ return Sock.local_endpoint().address().to_string();
+ }
+ catch (const std::exception&)
+ {
+ return GetMachineName();
+ }
+}
+
+uint64_t
+HttpSysServer::GetTotalBytesReceived() const
+{
+ return m_TotalBytesReceived.load(std::memory_order_relaxed);
+}
+
+uint64_t
+HttpSysServer::GetTotalBytesSent() const
+{
+ return m_TotalBytesSent.load(std::memory_order_relaxed);
+}
+
void
HttpSysServer::OnRegisterService(HttpService& Service)
{
RegisterService(Service.BaseUri(), Service);
}
+void
+HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ RwLock::ExclusiveLockScope _(m_RequestFilterLock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+IHttpRequestFilter::Result
+HttpSysServer::FilterRequest(HttpSysServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_RequestFilterLock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ return RequestFilter->FilterRequest(Request);
+}
+
Ref<HttpServer>
CreateHttpSysServer(HttpSysConfig Config)
{
diff --git a/src/zenhttp/servers/httpsys_iocontext.h b/src/zenhttp/servers/httpsys_iocontext.h
new file mode 100644
index 000000000..4f8a97012
--- /dev/null
+++ b/src/zenhttp/servers/httpsys_iocontext.h
@@ -0,0 +1,40 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#if ZEN_WITH_HTTPSYS
+# define _WINSOCKAPI_
+# include <zencore/windows.h>
+
+# include <cstdint>
+
+namespace zen {
+
+/**
+ * Tagged OVERLAPPED wrapper for http.sys IOCP dispatch
+ *
+ * Both HttpSysTransaction (for normal HTTP request I/O) and WsHttpSysConnection
+ * (for WebSocket read/write) embed this struct. The single IoCompletionCallback
+ * bound to the request queue uses the ContextType tag to dispatch to the correct
+ * handler.
+ *
+ * The Overlapped member must be first so that CONTAINING_RECORD works to recover
+ * the HttpSysIoContext from the OVERLAPPED pointer provided by the threadpool.
+ */
struct HttpSysIoContext
{
    // Must remain the FIRST member: the IOCP callback recovers this context
    // from the raw OVERLAPPED* via CONTAINING_RECORD.
    OVERLAPPED Overlapped{};

    // Discriminator the shared completion callback uses to route the
    // completion to the right handler.
    enum class Type : uint8_t
    {
        kTransaction,    // normal http.sys request/response I/O
        kWebSocketRead,  // WebSocket opaque-mode read completion
        kWebSocketWrite, // WebSocket opaque-mode write completion
    } ContextType = Type::kTransaction;

    // Non-owning back-pointer to the object that issued the I/O
    // (HttpSysTransaction or WsHttpSysConnection, depending on ContextType).
    void* Owner = nullptr;
};
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/httptracer.h b/src/zenhttp/servers/httptracer.h
index da72c79c9..a9a45f162 100644
--- a/src/zenhttp/servers/httptracer.h
+++ b/src/zenhttp/servers/httptracer.h
@@ -1,9 +1,9 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zenhttp/httpserver.h>
-
#pragma once
+#include <zenhttp/httpserver.h>
+
namespace zen {
/** Helper class for HTTP server implementations
diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp
new file mode 100644
index 000000000..b2543277a
--- /dev/null
+++ b/src/zenhttp/servers/wsasio.cpp
@@ -0,0 +1,311 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wsasio.h"
+#include "wsframecodec.h"
+
+#include <zencore/logging.h>
+#include <zenhttp/httpserver.h>
+
+namespace zen::asio_http {
+
// Lazily-created logger for the ASIO WebSocket transport ("ws" channel).
// Function-local static keeps initialization thread-safe and on first use.
static LoggerRef
WsLog()
{
    static LoggerRef g_Logger = logging::Get("ws");
    return g_Logger;
}
+
+//////////////////////////////////////////////////////////////////////////
+
// Takes ownership of the already-upgraded TCP socket. Handler receives message
// and close callbacks; Server may be null and is used only for statistics.
WsAsioConnection::WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server)
: m_Socket(std::move(Socket))
, m_Handler(Handler)
, m_HttpServer(Server)
{
}

WsAsioConnection::~WsAsioConnection()
{
    // Mark closed so any late callback observing the flag becomes a no-op,
    // then balance the server's open-connection accounting.
    // NOTE(review): the matching "connection opened" accounting is not visible
    // in this file — confirm the server increments on upgrade.
    m_IsOpen.store(false);
    if (m_HttpServer)
    {
        m_HttpServer->OnWebSocketConnectionClosed();
    }
}

// Kick off the async read loop; call exactly once, after the 101 response.
void
WsAsioConnection::Start()
{
    EnqueueRead();
}

// True until the connection is closed locally or the peer disconnects.
bool
WsAsioConnection::IsOpen() const
{
    return m_IsOpen.load(std::memory_order_relaxed);
}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Read loop
+//
+
+void
+WsAsioConnection::EnqueueRead()
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ Ref<WsAsioConnection> Self(this);
+
+ asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) {
+ Self->OnDataReceived(Ec, ByteCount);
+ });
+}
+
+void
+WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+{
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(WsLog(), "WebSocket read error: {}", Ec.message());
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection lost");
+ }
+ return;
+ }
+
+ ProcessReceivedData();
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ EnqueueRead();
+ }
+}
+
/**
 * Drain every complete WebSocket frame currently buffered in m_ReadBuffer and
 * dispatch it: data frames go to the handler, pings are answered with pongs,
 * a close frame is echoed and tears down the socket.
 */
void
WsAsioConnection::ProcessReceivedData()
{
    while (m_ReadBuffer.size() > 0)
    {
        const auto& InputBuffer = m_ReadBuffer.data();
        const auto* Data = static_cast<const uint8_t*>(InputBuffer.data());
        const auto Size = InputBuffer.size();

        WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size);
        if (!Frame.IsValid)
        {
            break; // not enough data yet — wait for the next read
        }

        // Consume before dispatch so a handler that throws cannot re-process
        // the same bytes.
        m_ReadBuffer.consume(Frame.BytesConsumed);

        if (m_HttpServer)
        {
            m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed);
        }

        switch (Frame.Opcode)
        {
            case WebSocketOpcode::kText:
            case WebSocketOpcode::kBinary:
            {
                // Payload is cloned into an IoBuffer so the message can
                // outlive this parse loop.
                WebSocketMessage Msg;
                Msg.Opcode = Frame.Opcode;
                Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
                m_Handler.OnWebSocketMessage(*this, Msg);
                break;
            }

            case WebSocketOpcode::kPing:
            {
                // Auto-respond with pong carrying the same payload (RFC 6455 §5.5.3)
                std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload);
                EnqueueWrite(std::move(PongFrame));
                break;
            }

            case WebSocketOpcode::kPong:
                // Unsolicited pong — ignore per RFC 6455
                break;

            case WebSocketOpcode::kClose:
            {
                // Default close code when the peer supplies no payload.
                uint16_t Code = 1000;
                std::string_view Reason;

                if (Frame.Payload.size() >= 2)
                {
                    // First two payload bytes are the big-endian status code,
                    // the remainder (if any) is the UTF-8 reason string.
                    // Reason views Frame.Payload — valid only for the handler call below.
                    Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
                    if (Frame.Payload.size() > 2)
                    {
                        Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
                    }
                }

                // Echo close frame back if we haven't sent one yet
                if (!m_CloseSent.exchange(true))
                {
                    std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code);
                    EnqueueWrite(std::move(CloseFrame));
                }

                m_IsOpen.store(false);
                m_Handler.OnWebSocketClose(*this, Code, Reason);

                // Shut down the socket
                // NOTE(review): the close echo above was only *enqueued*
                // (async_write); shutting the socket down immediately here may
                // abort that write before it reaches the wire — confirm this
                // best-effort echo is intended.
                std::error_code ShutdownEc;
                m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc);
                m_Socket->close(ShutdownEc);
                return;
            }

            default:
                // Includes continuation frames (opcode 0x0): fragmented
                // messages are not reassembled here — presumably peers send
                // single-frame messages only; verify against clients.
                ZEN_LOG_WARN(WsLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
                break;
        }
    }
}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Write queue
+//
+
+void
+WsAsioConnection::SendText(std::string_view Text)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsAsioConnection::SendBinary(std::span<const uint8_t> Data)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsAsioConnection::Close(uint16_t Code, std::string_view Reason)
+{
+ DoClose(Code, Reason);
+}
+
+void
+WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason)
+{
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+}
+
/**
 * Queue a fully-built frame for sending. m_IsWriting enforces a single
 * in-flight async_write at a time, which preserves frame order; whichever
 * caller flips it from false starts the drain.
 *
 * NOTE(review): the byte counter is bumped at enqueue time, before the write
 * is known to succeed — confirm "sent" statistics are meant to be optimistic.
 */
void
WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame)
{
    if (m_HttpServer)
    {
        m_HttpServer->OnWebSocketFrameSent(Frame.size());
    }

    bool ShouldFlush = false;

    m_WriteLock.WithExclusiveLock([&] {
        m_WriteQueue.push_back(std::move(Frame));
        if (!m_IsWriting)
        {
            m_IsWriting = true;
            ShouldFlush = true;
        }
    });

    // Start the write outside the lock.
    if (ShouldFlush)
    {
        FlushWriteQueue();
    }
}

/**
 * Pop the next frame and start one async_write for it. Called by the drain
 * owner only (the thread that set m_IsWriting, or a write completion).
 */
void
WsAsioConnection::FlushWriteQueue()
{
    std::vector<uint8_t> Frame;

    m_WriteLock.WithExclusiveLock([&] {
        if (m_WriteQueue.empty())
        {
            m_IsWriting = false;
            return; // exits only the lambda, not FlushWriteQueue
        }
        Frame = std::move(m_WriteQueue.front());
        m_WriteQueue.pop_front();
    });

    // Empty Frame means the lambda took the queue-empty branch above; real
    // frames are never empty (BuildFrame always emits at least a header).
    if (Frame.empty())
    {
        return;
    }

    // Keep the connection alive for the duration of the write.
    Ref<WsAsioConnection> Self(this);

    // Move Frame into a shared_ptr so we can create the buffer and capture ownership
    // in the same async_write call without evaluation order issues.
    auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));

    asio::async_write(*m_Socket,
                      asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
                      [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); });
}

/**
 * Write completion: on success, continue draining the queue; on failure, drop
 * all pending frames and notify the handler once with 1006.
 */
void
WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
{
    if (Ec)
    {
        if (Ec != asio::error::operation_aborted)
        {
            ZEN_LOG_DEBUG(WsLog(), "WebSocket write error: {}", Ec.message());
        }

        // A failed write poisons the stream — discard everything queued.
        m_WriteLock.WithExclusiveLock([&] {
            m_IsWriting = false;
            m_WriteQueue.clear();
        });

        if (m_IsOpen.exchange(false))
        {
            m_Handler.OnWebSocketClose(*this, 1006, "write error");
        }
        return;
    }

    FlushWriteQueue();
}
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h
new file mode 100644
index 000000000..e8bb3b1d2
--- /dev/null
+++ b/src/zenhttp/servers/wsasio.h
@@ -0,0 +1,77 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include <zencore/thread.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <deque>
+#include <memory>
+#include <vector>
+
+namespace zen {
+class HttpServer;
+} // namespace zen
+
+namespace zen::asio_http {
+
+/**
+ * WebSocket connection over an ASIO TCP socket
+ *
+ * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake)
+ * and runs an async read/write loop to exchange WebSocket frames.
+ *
+ * Lifetime is managed solely through intrusive reference counting (RefCounted).
+ * The async read/write callbacks capture Ref<WsAsioConnection> to keep the
+ * connection alive for the duration of the async operation. The service layer
+ * also holds a Ref<WebSocketConnection>.
+ */
+
class WsAsioConnection : public WebSocketConnection
{
public:
    // Socket is the post-upgrade TCP socket (ownership transferred); Server
    // may be null and is only used for statistics callbacks.
    WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server);
    ~WsAsioConnection() override;

    /**
     * Start the async read loop. Must be called once after construction
     * and the 101 response has been sent.
     */
    void Start();

    // WebSocketConnection interface
    void SendText(std::string_view Text) override;
    void SendBinary(std::span<const uint8_t> Data) override;
    void Close(uint16_t Code, std::string_view Reason) override;
    bool IsOpen() const override;

private:
    // Read side: single outstanding async_read feeding the frame parser.
    void EnqueueRead();
    void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
    void ProcessReceivedData();

    // Write side: frames are queued and drained one async_write at a time.
    void EnqueueWrite(std::vector<uint8_t> Frame);
    void FlushWriteQueue();
    void OnWriteComplete(const asio::error_code& Ec, std::size_t ByteCount);

    // Shared close path used by Close(); sends at most one close frame.
    void DoClose(uint16_t Code, std::string_view Reason);

    std::unique_ptr<asio::ip::tcp::socket> m_Socket; // owned, post-upgrade socket
    IWebSocketHandler& m_Handler;                    // message/close callbacks
    zen::HttpServer* m_HttpServer;                   // statistics sink; may be null
    asio::streambuf m_ReadBuffer;                    // bytes accumulated until a full frame parses

    RwLock m_WriteLock;                              // guards the three write-queue members below
    std::deque<std::vector<uint8_t>> m_WriteQueue;   // frames awaiting transmission
    bool m_IsWriting = false;                        // true while an async_write is in flight

    std::atomic<bool> m_IsOpen{true};                // cleared exactly once on close/error
    std::atomic<bool> m_CloseSent{false};            // ensures at most one close frame is sent
};
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp
new file mode 100644
index 000000000..e452141fe
--- /dev/null
+++ b/src/zenhttp/servers/wsframecodec.cpp
@@ -0,0 +1,236 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wsframecodec.h"
+
+#include <zencore/base64.h>
+#include <zencore/sha1.h>
+
+#include <cstring>
+#include <random>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame parsing
+//
+
/**
 * Try to parse one RFC 6455 frame from the front of [Data, Data+Size).
 *
 * Returns a default WsFrameParseResult (IsValid == false, BytesConsumed == 0)
 * when the buffer does not yet hold a complete frame or the frame is rejected.
 * Masked payloads are unmasked in place in the returned copy.
 *
 * NOTE(review): RSV bits are ignored and masking is not *required* of the
 * peer (a server should fail unmasked client frames per RFC 6455 §5.1); Fin
 * is reported but fragmentation reassembly is left to the caller — confirm
 * these relaxations are intended.
 */
WsFrameParseResult
WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size)
{
    // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames)
    if (Size < 2)
    {
        return {};
    }

    // Byte 0: FIN flag + opcode nibble. Byte 1: mask flag + 7-bit length.
    const bool Fin = (Data[0] & 0x80) != 0;
    const uint8_t OpcodeRaw = Data[0] & 0x0F;
    const bool Masked = (Data[1] & 0x80) != 0;
    uint64_t PayloadLen = Data[1] & 0x7F;

    size_t HeaderSize = 2;

    if (PayloadLen == 126)
    {
        // 126 => 16-bit big-endian extended length follows.
        if (Size < 4)
        {
            return {};
        }
        PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]);
        HeaderSize = 4;
    }
    else if (PayloadLen == 127)
    {
        // 127 => 64-bit big-endian extended length follows.
        if (Size < 10)
        {
            return {};
        }
        PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) |
                     (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]);
        HeaderSize = 10;
    }

    // Reject frames with unreasonable payload sizes to prevent OOM
    static constexpr uint64_t kMaxPayloadSize = 256 * 1024 * 1024; // 256 MB
    if (PayloadLen > kMaxPayloadSize)
    {
        return {};
    }

    // Client frames carry a 4-byte masking key between header and payload.
    const size_t MaskSize = Masked ? 4 : 0;
    const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen;

    if (Size < TotalFrame)
    {
        return {}; // incomplete — caller accumulates more data and retries
    }

    const uint8_t* MaskKey = Masked ? (Data + HeaderSize) : nullptr;
    const uint8_t* PayloadData = Data + HeaderSize + MaskSize;

    WsFrameParseResult Result;
    Result.IsValid = true;
    Result.BytesConsumed = TotalFrame;
    Result.Opcode = static_cast<WebSocketOpcode>(OpcodeRaw);
    Result.Fin = Fin;

    // Copy out the payload, then XOR-unmask in place if needed (RFC 6455 §5.3:
    // octet i is XORed with mask byte i mod 4).
    Result.Payload.resize(static_cast<size_t>(PayloadLen));
    if (PayloadLen > 0)
    {
        std::memcpy(Result.Payload.data(), PayloadData, static_cast<size_t>(PayloadLen));

        if (Masked)
        {
            for (size_t i = 0; i < Result.Payload.size(); ++i)
            {
                Result.Payload[i] ^= MaskKey[i & 3];
            }
        }
    }

    return Result;
}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame building (server-to-client, no masking)
+//
+
+std::vector<uint8_t>
+WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+{
+ std::vector<uint8_t> Frame;
+
+ const size_t PayloadLen = Payload.size();
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length (no mask bit for server frames)
+ if (PayloadLen < 126)
+ {
+ Frame.push_back(static_cast<uint8_t>(PayloadLen));
+ }
+ else if (PayloadLen <= 0xFFFF)
+ {
+ Frame.push_back(126);
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF));
+ }
+ }
+
+ Frame.insert(Frame.end(), Payload.begin(), Payload.end());
+
+ return Frame;
+}
+
+std::vector<uint8_t>
+WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason)
+{
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ Payload.insert(Payload.end(), Reason.begin(), Reason.end());
+
+ return BuildFrame(WebSocketOpcode::kClose, Payload);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame building (client-to-server, with masking)
+//
+
/**
 * Build a masked (client-to-server) frame. Identical layout to BuildFrame
 * except the mask bit is set, a 4-byte mask key follows the length field, and
 * the payload is XOR-masked (RFC 6455 §5.3).
 *
 * NOTE(review): the mask key comes from a thread-local mt19937 seeded once
 * from std::random_device. RFC 6455 §5.3 asks clients for keys from a strong
 * source of entropy (defends against proxy cache poisoning); mt19937 is not
 * cryptographically strong — confirm this client path is only used against
 * trusted/internal servers.
 */
std::vector<uint8_t>
WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
{
    std::vector<uint8_t> Frame;

    const size_t PayloadLen = Payload.size();

    // FIN + opcode
    Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));

    // Payload length with mask bit set
    if (PayloadLen < 126)
    {
        Frame.push_back(0x80 | static_cast<uint8_t>(PayloadLen));
    }
    else if (PayloadLen <= 0xFFFF)
    {
        // 16-bit big-endian extended length
        Frame.push_back(0x80 | 126);
        Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF));
        Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF));
    }
    else
    {
        // 64-bit big-endian extended length, most significant byte first
        Frame.push_back(0x80 | 127);
        for (int i = 7; i >= 0; --i)
        {
            Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF));
        }
    }

    // Generate random 4-byte mask key
    static thread_local std::mt19937 s_Rng(std::random_device{}());
    uint32_t MaskValue = s_Rng();
    uint8_t MaskKey[4];
    std::memcpy(MaskKey, &MaskValue, 4);

    Frame.insert(Frame.end(), MaskKey, MaskKey + 4);

    // Masked payload: octet i XORed with mask byte (i mod 4)
    for (size_t i = 0; i < PayloadLen; ++i)
    {
        Frame.push_back(Payload[i] ^ MaskKey[i & 3]);
    }

    return Frame;
}

/**
 * Build a masked close frame: 2-byte big-endian status code plus the optional
 * UTF-8 reason, framed via BuildMaskedFrame.
 */
std::vector<uint8_t>
WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason)
{
    std::vector<uint8_t> Payload;
    Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
    Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
    Payload.insert(Payload.end(), Reason.begin(), Reason.end());

    return BuildMaskedFrame(WebSocketOpcode::kClose, Payload);
}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2)
+//
+
+static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+
/**
 * Compute the Sec-WebSocket-Accept header value for a handshake:
 * Base64(SHA1(ClientKey + magic GUID)) per RFC 6455 §4.2.2.
 */
std::string
WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey)
{
    // Concatenate client key with the magic GUID
    std::string Combined;
    Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size());
    Combined.append(ClientKey);
    Combined.append(kWebSocketMagicGuid);

    // SHA1 hash (20 bytes)
    SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size());

    // Base64 encode the 20-byte hash
    // (array bound requires Base64::GetEncodedDataSize to be constexpr —
    // otherwise this is a non-standard VLA; presumably it is, verify.)
    char Base64Buf[Base64::GetEncodedDataSize(20) + 1];
    uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf);
    Base64Buf[EncodedLen] = '\0';

    return std::string(Base64Buf, EncodedLen);
}
+
+} // namespace zen
diff --git a/src/zenhttp/servers/wsframecodec.h b/src/zenhttp/servers/wsframecodec.h
new file mode 100644
index 000000000..2d90b6fa1
--- /dev/null
+++ b/src/zenhttp/servers/wsframecodec.h
@@ -0,0 +1,74 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <span>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace zen {
+
+/**
+ * Result of attempting to parse a single WebSocket frame from a byte buffer
+ */
struct WsFrameParseResult
{
    bool IsValid = false;      // true if a complete frame was successfully parsed
    size_t BytesConsumed = 0;  // number of bytes consumed from the input buffer
    WebSocketOpcode Opcode = WebSocketOpcode::kText; // opcode nibble from the frame header
    bool Fin = false;          // FIN bit — true when this frame ends a message
    std::vector<uint8_t> Payload; // payload copy, already unmasked if the frame was masked
};
+
+/**
+ * RFC 6455 WebSocket frame codec
+ *
+ * Provides static helpers for parsing client-to-server frames (which are
+ * always masked) and building server-to-client frames (which are never masked).
+ */
struct WsFrameCodec
{
    // All members are static and the codec holds no state, so the helpers are
    // safe to call concurrently from any thread.

    /**
     * Try to parse one complete frame from the front of the buffer.
     *
     * Returns a result with IsValid == false and BytesConsumed == 0 when
     * there is not enough data yet. The caller should accumulate more data
     * and retry.
     */
    static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size);

    /**
     * Build a server-to-client frame (no masking)
     */
    static std::vector<uint8_t> BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload);

    /**
     * Build a close frame with a status code and optional reason string
     * (2-byte big-endian code followed by the UTF-8 reason)
     */
    static std::vector<uint8_t> BuildCloseFrame(uint16_t Code, std::string_view Reason = {});

    /**
     * Build a client-to-server frame (with masking per RFC 6455)
     */
    static std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload);

    /**
     * Build a masked close frame with status code and optional reason
     */
    static std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason = {});

    /**
     * Compute the Sec-WebSocket-Accept value per RFC 6455 section 4.2.2
     *
     * accept = Base64(SHA1(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
     */
    static std::string ComputeAcceptKey(std::string_view ClientKey);
};
+
+} // namespace zen
diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp
new file mode 100644
index 000000000..af320172d
--- /dev/null
+++ b/src/zenhttp/servers/wshttpsys.cpp
@@ -0,0 +1,485 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wshttpsys.h"
+
+#if ZEN_WITH_HTTPSYS
+
+# include "wsframecodec.h"
+
+# include <zencore/logging.h>
+# include <zenhttp/httpserver.h>
+
+namespace zen {
+
// Lazily-created logger for the http.sys WebSocket transport ("ws_httpsys"
// channel). Function-local static keeps initialization thread-safe.
static LoggerRef
WsHttpSysLog()
{
    static LoggerRef g_Logger = logging::Get("ws_httpsys");
    return g_Logger;
}
+
+//////////////////////////////////////////////////////////////////////////
+
// Binds to the original http.sys request (queue handle + request id) that was
// upgraded via the opaque 101 response. Iocp is the shared threadpool I/O
// object for the request queue; Server may be null (statistics only).
WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle,
                                         HTTP_REQUEST_ID RequestId,
                                         IWebSocketHandler& Handler,
                                         PTP_IO Iocp,
                                         HttpServer* Server)
: m_RequestQueueHandle(RequestQueueHandle)
, m_RequestId(RequestId)
, m_Handler(Handler)
, m_Iocp(Iocp)
, m_HttpServer(Server)
, m_ReadBuffer(8192) // fixed 8 KiB receive buffer; larger frames accumulate across reads
{
    // Tag the embedded OVERLAPPED contexts so the shared IOCP callback can
    // route read/write completions back to this connection.
    m_ReadIoContext.ContextType = HttpSysIoContext::Type::kWebSocketRead;
    m_ReadIoContext.Owner = this;
    m_WriteIoContext.ContextType = HttpSysIoContext::Type::kWebSocketWrite;
    m_WriteIoContext.Owner = this;
}

WsHttpSysConnection::~WsHttpSysConnection()
{
    // The self-reference is only released after all async ops have drained,
    // so by destruction time none may be outstanding.
    ZEN_ASSERT(m_OutstandingOps.load() == 0);

    if (m_IsOpen.exchange(false))
    {
        Disconnect();
    }

    if (m_HttpServer)
    {
        m_HttpServer->OnWebSocketConnectionClosed();
    }
}

// Begin the async read loop; the self-reference keeps this object alive until
// the connection closes and all completions have drained.
void
WsHttpSysConnection::Start()
{
    m_SelfRef = Ref<WsHttpSysConnection>(this);
    IssueAsyncRead();
}

// Request shutdown: flips the open flag (first caller only) and cancels
// pending I/O so completions drain with ERROR_OPERATION_ABORTED.
void
WsHttpSysConnection::Shutdown()
{
    m_ShutdownRequested.store(true, std::memory_order_relaxed);

    if (!m_IsOpen.exchange(false))
    {
        return; // already closed
    }

    // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED
    HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
}

// True until the connection is closed locally or the peer disconnects.
bool
WsHttpSysConnection::IsOpen() const
{
    return m_IsOpen.load(std::memory_order_relaxed);
}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Async read path
+//
+
/**
 * Issue one async opaque-mode read. Each issued op bumps m_OutstandingOps;
 * the matching completion decrements it. StartThreadpoolIo must be paired
 * with CancelThreadpoolIo when the API fails synchronously, otherwise the
 * threadpool waits for a completion that will never arrive.
 */
void
WsHttpSysConnection::IssueAsyncRead()
{
    if (!m_IsOpen.load(std::memory_order_relaxed) || m_ShutdownRequested.load(std::memory_order_relaxed))
    {
        // Connection is going away — this may be the last chance to drop the
        // self-reference.
        MaybeReleaseSelfRef();
        return;
    }

    m_OutstandingOps.fetch_add(1, std::memory_order_relaxed);

    // OVERLAPPED must be zeroed before reuse on every issue.
    ZeroMemory(&m_ReadIoContext.Overlapped, sizeof(OVERLAPPED));

    StartThreadpoolIo(m_Iocp);

    ULONG Result = HttpReceiveRequestEntityBody(m_RequestQueueHandle,
                                                m_RequestId,
                                                0, // Flags
                                                m_ReadBuffer.data(),
                                                (ULONG)m_ReadBuffer.size(),
                                                nullptr, // BytesRead (ignored for async)
                                                &m_ReadIoContext.Overlapped);

    // NO_ERROR with an OVERLAPPED still delivers an IOCP completion —
    // presumably per HTTP Server API semantics; only a genuine failure is
    // handled inline here. Verify against the HttpReceiveRequestEntityBody docs.
    if (Result != NO_ERROR && Result != ERROR_IO_PENDING)
    {
        CancelThreadpoolIo(m_Iocp);
        m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);

        if (m_IsOpen.exchange(false))
        {
            m_Handler.OnWebSocketClose(*this, 1006, "read issue failed");
        }

        MaybeReleaseSelfRef();
    }
}

/**
 * IOCP read completion (dispatched by the shared callback via the tagged
 * context). On success, accumulates bytes, parses frames, and re-arms the
 * read; on failure, notifies the handler once and lets the self-ref drop.
 */
void
WsHttpSysConnection::OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
    // Hold a transient ref to prevent mid-callback destruction after MaybeReleaseSelfRef
    Ref<WsHttpSysConnection> Guard(this);

    if (IoResult != NO_ERROR)
    {
        m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);

        if (m_IsOpen.exchange(false))
        {
            if (IoResult == ERROR_HANDLE_EOF)
            {
                // Peer finished the opaque stream — normal remote disconnect.
                m_Handler.OnWebSocketClose(*this, 1006, "connection closed");
            }
            else if (IoResult != ERROR_OPERATION_ABORTED)
            {
                // Aborted means DoClose/Shutdown already notified the handler;
                // anything else is an abnormal transport failure.
                m_Handler.OnWebSocketClose(*this, 1006, "connection lost");
            }
        }

        MaybeReleaseSelfRef();
        return;
    }

    if (NumberOfBytesTransferred > 0)
    {
        // Append raw bytes to the accumulator, then parse complete frames.
        m_Accumulated.insert(m_Accumulated.end(), m_ReadBuffer.begin(), m_ReadBuffer.begin() + NumberOfBytesTransferred);
        ProcessReceivedData();
    }

    m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);

    if (m_IsOpen.load(std::memory_order_relaxed))
    {
        IssueAsyncRead();
    }
    else
    {
        MaybeReleaseSelfRef();
    }
}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame parsing
+//
+
/**
 * Drain every complete WebSocket frame from the accumulator and dispatch it:
 * data frames go to the handler, pings are answered with pongs, a close frame
 * is echoed and the opaque connection torn down. Mirrors the ASIO transport's
 * dispatch logic.
 */
void
WsHttpSysConnection::ProcessReceivedData()
{
    while (!m_Accumulated.empty())
    {
        WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size());
        if (!Frame.IsValid)
        {
            break; // not enough data yet — wait for the next read completion
        }

        // Remove consumed bytes
        m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed);

        if (m_HttpServer)
        {
            m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed);
        }

        switch (Frame.Opcode)
        {
            case WebSocketOpcode::kText:
            case WebSocketOpcode::kBinary:
            {
                // Payload is cloned into an IoBuffer so the message can
                // outlive this parse loop.
                WebSocketMessage Msg;
                Msg.Opcode = Frame.Opcode;
                Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
                m_Handler.OnWebSocketMessage(*this, Msg);
                break;
            }

            case WebSocketOpcode::kPing:
            {
                // Auto-respond with pong carrying the same payload
                std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload);
                EnqueueWrite(std::move(PongFrame));
                break;
            }

            case WebSocketOpcode::kPong:
                // Unsolicited pong — ignore per RFC 6455
                break;

            case WebSocketOpcode::kClose:
            {
                // Default close code when the peer supplies no payload.
                uint16_t Code = 1000;
                std::string_view Reason;

                if (Frame.Payload.size() >= 2)
                {
                    // Big-endian status code, then optional UTF-8 reason.
                    // Reason views Frame.Payload — valid only for the handler call below.
                    Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
                    if (Frame.Payload.size() > 2)
                    {
                        Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
                    }
                }

                // Echo close frame back if we haven't sent one yet
                {
                    bool ShouldSendClose = false;
                    {
                        // Flag checked under the write lock so it pairs with DoClose.
                        RwLock::ExclusiveLockScope _(m_WriteLock);
                        if (!m_CloseSent.exchange(true))
                        {
                            ShouldSendClose = true;
                        }
                    }
                    if (ShouldSendClose)
                    {
                        std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code);
                        EnqueueWrite(std::move(CloseFrame));
                    }
                }

                m_IsOpen.store(false);
                m_Handler.OnWebSocketClose(*this, Code, Reason);
                // NOTE(review): the close echo above was only *enqueued*; the
                // DISCONNECT issued here may race the in-flight write — confirm
                // the best-effort echo is intended. Remaining buffered bytes
                // are discarded by the early return.
                Disconnect();
                return;
            }

            default:
                // Includes continuation frames (opcode 0x0): fragmented
                // messages are not reassembled here — verify peers send
                // single-frame messages only.
                ZEN_LOG_WARN(WsHttpSysLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
                break;
        }
    }
}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Async write path
+//
+
/**
 * Queue a fully-built frame for async transmission. m_IsWriting (guarded by
 * m_WriteLock) enforces a single in-flight HttpSendResponseEntityBody at a
 * time, preserving frame order; whichever caller flips it starts the drain.
 *
 * NOTE(review): the byte counter is bumped at enqueue time, before the write
 * is known to succeed — confirm "sent" statistics are meant to be optimistic.
 */
void
WsHttpSysConnection::EnqueueWrite(std::vector<uint8_t> Frame)
{
    if (m_HttpServer)
    {
        m_HttpServer->OnWebSocketFrameSent(Frame.size());
    }

    bool ShouldFlush = false;

    {
        RwLock::ExclusiveLockScope _(m_WriteLock);
        m_WriteQueue.push_back(std::move(Frame));

        if (!m_IsWriting)
        {
            m_IsWriting = true;
            ShouldFlush = true;
        }
    }

    // Start the write outside the lock.
    if (ShouldFlush)
    {
        FlushWriteQueue();
    }
}

/**
 * Pop the next frame and issue one async HttpSendResponseEntityBody for it.
 * Only the drain owner calls this (the thread that set m_IsWriting, or a
 * write completion). m_CurrentWriteBuffer is a member because the buffer must
 * stay alive until the async write completes.
 */
void
WsHttpSysConnection::FlushWriteQueue()
{
    {
        RwLock::ExclusiveLockScope _(m_WriteLock);

        if (m_WriteQueue.empty())
        {
            m_IsWriting = false;
            return;
        }

        m_CurrentWriteBuffer = std::move(m_WriteQueue.front());
        m_WriteQueue.pop_front();
    }

    m_OutstandingOps.fetch_add(1, std::memory_order_relaxed);

    // Single memory chunk pointing at the member buffer (kept alive above).
    ZeroMemory(&m_WriteChunk, sizeof(m_WriteChunk));
    m_WriteChunk.DataChunkType = HttpDataChunkFromMemory;
    m_WriteChunk.FromMemory.pBuffer = m_CurrentWriteBuffer.data();
    m_WriteChunk.FromMemory.BufferLength = (ULONG)m_CurrentWriteBuffer.size();

    ZeroMemory(&m_WriteIoContext.Overlapped, sizeof(OVERLAPPED));

    StartThreadpoolIo(m_Iocp);

    // MORE_DATA keeps the opaque stream open for subsequent frames.
    ULONG Result = HttpSendResponseEntityBody(m_RequestQueueHandle,
                                              m_RequestId,
                                              HTTP_SEND_RESPONSE_FLAG_MORE_DATA,
                                              1,
                                              &m_WriteChunk,
                                              nullptr,
                                              nullptr,
                                              0,
                                              &m_WriteIoContext.Overlapped,
                                              nullptr);

    if (Result != NO_ERROR && Result != ERROR_IO_PENDING)
    {
        // Synchronous failure: undo the threadpool arm and op count, drop all
        // queued frames (the stream is poisoned), and notify the handler once.
        CancelThreadpoolIo(m_Iocp);
        m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);

        ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket async write failed: {}", Result);

        {
            RwLock::ExclusiveLockScope _(m_WriteLock);
            m_WriteQueue.clear();
            m_IsWriting = false;
        }
        m_CurrentWriteBuffer.clear();

        if (m_IsOpen.exchange(false))
        {
            m_Handler.OnWebSocketClose(*this, 1006, "write error");
        }

        MaybeReleaseSelfRef();
    }
}

/**
 * IOCP write completion. On success, continues draining the queue; on
 * failure, drops pending frames and notifies the handler once with 1006.
 */
void
WsHttpSysConnection::OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
    ZEN_UNUSED(NumberOfBytesTransferred);

    // Hold a transient ref to prevent mid-callback destruction
    Ref<WsHttpSysConnection> Guard(this);

    m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
    m_CurrentWriteBuffer.clear();

    if (IoResult != NO_ERROR)
    {
        ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket write completion error: {}", IoResult);

        {
            RwLock::ExclusiveLockScope _(m_WriteLock);
            m_WriteQueue.clear();
            m_IsWriting = false;
        }

        if (m_IsOpen.exchange(false))
        {
            m_Handler.OnWebSocketClose(*this, 1006, "write error");
        }

        MaybeReleaseSelfRef();
        return;
    }

    // NOTE(review): the success path never calls MaybeReleaseSelfRef — the
    // self-ref is released from the read path; confirm a connection whose
    // last outstanding op is a write still drains via a read completion.
    FlushWriteQueue();
}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Send interface
+//
+
/** Frame the text as an unmasked server frame and queue it; no-op when closed. */
void
WsHttpSysConnection::SendText(std::string_view Text)
{
    if (!m_IsOpen.load(std::memory_order_relaxed))
    {
        return;
    }

    std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
    std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
    EnqueueWrite(std::move(Frame));
}

/** Frame the bytes as an unmasked binary frame and queue it; no-op when closed. */
void
WsHttpSysConnection::SendBinary(std::span<const uint8_t> Data)
{
    if (!m_IsOpen.load(std::memory_order_relaxed))
    {
        return;
    }

    std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data);
    EnqueueWrite(std::move(Frame));
}

/** Public close entry point — forwards to the shared close path. */
void
WsHttpSysConnection::Close(uint16_t Code, std::string_view Reason)
{
    DoClose(Code, Reason);
}

/**
 * Shared close path: the first caller transitions to closed, sends a close
 * frame at most once (flag paired with ProcessReceivedData under the write
 * lock), notifies the handler, then cancels pending reads.
 */
void
WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason)
{
    if (!m_IsOpen.exchange(false))
    {
        return; // someone else already closed
    }

    {
        bool ShouldSendClose = false;
        {
            RwLock::ExclusiveLockScope _(m_WriteLock);
            if (!m_CloseSent.exchange(true))
            {
                ShouldSendClose = true;
            }
        }
        if (ShouldSendClose)
        {
            std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason);
            EnqueueWrite(std::move(CloseFrame));
        }
    }

    m_Handler.OnWebSocketClose(*this, Code, Reason);

    // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED
    HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Lifetime management
+//
+
/**
 * Drop the self-reference once the connection is closed and no async op is in
 * flight — this is what finally allows the object to be destroyed.
 *
 * NOTE(review): the two relaxed loads are not a single atomic snapshot;
 * correctness relies on this only being called from paths where no new op can
 * be issued concurrently (completion callbacks hold a transient Guard ref) —
 * confirm against the IOCP dispatch.
 */
void
WsHttpSysConnection::MaybeReleaseSelfRef()
{
    if (m_OutstandingOps.load(std::memory_order_relaxed) == 0 && !m_IsOpen.load(std::memory_order_relaxed))
    {
        m_SelfRef = nullptr;
    }
}

/**
 * Best-effort, fire-and-forget teardown: a final empty body with the
 * DISCONNECT flag tells http.sys the opaque connection is done. The return
 * value is intentionally ignored — there is nothing useful to do on failure.
 */
void
WsHttpSysConnection::Disconnect()
{
    // Send final empty body with DISCONNECT to tell http.sys the connection is done
    HttpSendResponseEntityBody(m_RequestQueueHandle,
                               m_RequestId,
                               HTTP_SEND_RESPONSE_FLAG_DISCONNECT,
                               0,
                               nullptr,
                               nullptr,
                               nullptr,
                               0,
                               nullptr,
                               nullptr);
}
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h
new file mode 100644
index 000000000..6015e3873
--- /dev/null
+++ b/src/zenhttp/servers/wshttpsys.h
@@ -0,0 +1,107 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include "httpsys_iocontext.h"
+
+#include <zencore/thread.h>
+
+#if ZEN_WITH_HTTPSYS
+# define _WINSOCKAPI_
+# include <zencore/windows.h>
+# include <http.h>
+
+# include <atomic>
+# include <deque>
+# include <vector>
+
+namespace zen {
+
+class HttpServer;
+
+/**
+ * WebSocket connection over an http.sys opaque-mode connection
+ *
+ * After the 101 Switching Protocols response is sent with
+ * HTTP_SEND_RESPONSE_FLAG_OPAQUE, http.sys stops parsing HTTP on the
+ * connection. Raw bytes are exchanged via HttpReceiveRequestEntityBody /
+ * HttpSendResponseEntityBody using the original RequestId.
+ *
+ * All I/O is performed asynchronously via the same IOCP threadpool used
+ * for normal http.sys traffic, eliminating per-connection threads.
+ *
+ * Lifetime is managed through intrusive reference counting (RefCounted).
+ * A self-reference (m_SelfRef) is held from Start() until all outstanding
+ * async operations have drained, preventing premature destruction.
+ */
+class WsHttpSysConnection : public WebSocketConnection
+{
+public:
+ WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp, HttpServer* Server);
+ ~WsHttpSysConnection() override;
+
+ /**
+ * Start the async read loop. Must be called once after construction
+ * and after the 101 response has been sent.
+ */
+ void Start();
+
+ /**
+ * Shut down the connection. Cancels pending I/O; IOCP completions
+ * will fire with ERROR_OPERATION_ABORTED and drain naturally.
+ */
+ void Shutdown();
+
+ // WebSocketConnection interface
+ void SendText(std::string_view Text) override;
+ void SendBinary(std::span<const uint8_t> Data) override;
+ void Close(uint16_t Code, std::string_view Reason) override;
+ bool IsOpen() const override;
+
+ // Called from IoCompletionCallback via tagged dispatch
+ void OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
+ void OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
+
+private:
+ void IssueAsyncRead();
+ void ProcessReceivedData();
+ void EnqueueWrite(std::vector<uint8_t> Frame);
+ void FlushWriteQueue();
+ void DoClose(uint16_t Code, std::string_view Reason);
+ void Disconnect();
+ void MaybeReleaseSelfRef();
+
+ // Identity of the underlying http.sys request: queue handle + request id
+ HANDLE m_RequestQueueHandle;
+ HTTP_REQUEST_ID m_RequestId;
+ // Application callback sink; not owned — assumed to outlive this connection
+ IWebSocketHandler& m_Handler;
+ // Threadpool I/O object shared with regular http.sys traffic
+ PTP_IO m_Iocp;
+ HttpServer* m_HttpServer;
+
+ // Tagged OVERLAPPED contexts for concurrent read and write
+ HttpSysIoContext m_ReadIoContext{};
+ HttpSysIoContext m_WriteIoContext{};
+
+ // Read state
+ std::vector<uint8_t> m_ReadBuffer;
+ std::vector<uint8_t> m_Accumulated;
+
+ // Write state (m_WriteLock guards the queue and the in-flight buffer)
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ std::vector<uint8_t> m_CurrentWriteBuffer;
+ HTTP_DATA_CHUNK m_WriteChunk{};
+ bool m_IsWriting = false;
+
+ // Lifetime management
+ std::atomic<int32_t> m_OutstandingOps{0};
+ Ref<WsHttpSysConnection> m_SelfRef;
+ std::atomic<bool> m_ShutdownRequested{false};
+ std::atomic<bool> m_IsOpen{true};
+ std::atomic<bool> m_CloseSent{false};
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
new file mode 100644
index 000000000..2134e4ff1
--- /dev/null
+++ b/src/zenhttp/servers/wstest.cpp
@@ -0,0 +1,925 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#if ZEN_WITH_TESTS
+
+# include <zencore/scopeguard.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+
+# include <zenhttp/httpserver.h>
+# include <zenhttp/httpwsclient.h>
+# include <zenhttp/websocket.h>
+
+# include "httpasio.h"
+# include "wsframecodec.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# if ZEN_PLATFORM_WINDOWS
+# include <winsock2.h>
+# else
+# include <poll.h>
+# include <sys/socket.h>
+# endif
+# include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+# include <atomic>
+# include <chrono>
+# include <cstring>
+# include <random>
+# include <string>
+# include <string_view>
+# include <thread>
+# include <vector>
+
+namespace zen {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Unit tests: WsFrameCodec
+//
+
+TEST_SUITE_BEGIN("http.wstest");
+
+// Pure codec unit tests: every subcase round-trips frames entirely in
+// memory — no sockets or server involved.
+TEST_CASE("websocket.framecodec")
+{
+ SUBCASE("ComputeAcceptKey RFC 6455 test vector")
+ {
+ // RFC 6455 section 4.2.2 example
+ std::string AcceptKey = WsFrameCodec::ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
+ CHECK_EQ(AcceptKey, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
+ }
+
+ SUBCASE("BuildFrame and TryParseFrame roundtrip - text")
+ {
+ std::string_view Text = "Hello, WebSocket!";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+
+ // Server frames are unmasked — TryParseFrame should handle them
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, Frame.size());
+ CHECK(Result.Fin);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), Text.size());
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text);
+ }
+
+ SUBCASE("BuildFrame and TryParseFrame roundtrip - binary")
+ {
+ std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD};
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, BinaryData);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary);
+ CHECK_EQ(Result.Payload, BinaryData);
+ }
+
+ // 300 bytes forces the 16-bit extended length encoding (RFC 6455 len=126)
+ SUBCASE("BuildFrame - medium payload (126-65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(300, 0x42);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 300u);
+ CHECK_EQ(Result.Payload, Payload);
+ }
+
+ // 70000 bytes forces the 64-bit extended length encoding (len=127)
+ SUBCASE("BuildFrame - large payload (>65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(70000, 0xAB);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 70000u);
+ }
+
+ SUBCASE("BuildCloseFrame roundtrip")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildCloseFrame(1000, "normal closure");
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose);
+ REQUIRE(Result.Payload.size() >= 2);
+
+ // Close payload starts with a big-endian 16-bit status code
+ uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]);
+ CHECK_EQ(Code, 1000);
+
+ std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2);
+ CHECK_EQ(Reason, "normal closure");
+ }
+
+ SUBCASE("TryParseFrame - partial data returns invalid")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
+
+ // Pass only 1 byte — not enough for a frame header
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1);
+ CHECK_FALSE(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, 0u);
+ }
+
+ SUBCASE("TryParseFrame - empty payload")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK(Result.Payload.empty());
+ }
+
+ SUBCASE("TryParseFrame - masked client frame")
+ {
+ // Build a masked frame manually as a client would send
+ // Frame: FIN=1, opcode=text, MASK=1, payload_len=5, mask_key=0x37FA213D, payload="Hello"
+ uint8_t MaskKey[4] = {0x37, 0xFA, 0x21, 0x3D};
+ uint8_t MaskedPayload[5] = {};
+ const char* Original = "Hello";
+ for (int i = 0; i < 5; ++i)
+ {
+ MaskedPayload[i] = static_cast<uint8_t>(Original[i]) ^ MaskKey[i % 4];
+ }
+
+ std::vector<uint8_t> Frame;
+ Frame.push_back(0x81); // FIN + text
+ Frame.push_back(0x85); // MASK + len=5
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+ Frame.insert(Frame.end(), MaskedPayload, MaskedPayload + 5);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), 5u);
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), 5), "Hello"sv);
+ }
+
+ SUBCASE("BuildMaskedFrame roundtrip - text")
+ {
+ std::string_view Text = "Hello, masked WebSocket!";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+
+ // Verify mask bit is set
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, Frame.size());
+ CHECK(Result.Fin);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), Text.size());
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text);
+ }
+
+ SUBCASE("BuildMaskedFrame roundtrip - binary")
+ {
+ std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD};
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, BinaryData);
+
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary);
+ CHECK_EQ(Result.Payload, BinaryData);
+ }
+
+ SUBCASE("BuildMaskedFrame - medium payload (126-65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(300, 0x42);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload);
+
+ CHECK((Frame[1] & 0x80) != 0);
+ CHECK_EQ((Frame[1] & 0x7F), 126); // 16-bit extended length
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 300u);
+ CHECK_EQ(Result.Payload, Payload);
+ }
+
+ SUBCASE("BuildMaskedFrame - large payload (>65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(70000, 0xAB);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload);
+
+ CHECK((Frame[1] & 0x80) != 0);
+ CHECK_EQ((Frame[1] & 0x7F), 127); // 64-bit extended length
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 70000u);
+ }
+
+ SUBCASE("BuildMaskedCloseFrame roundtrip")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure");
+
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose);
+ REQUIRE(Result.Payload.size() >= 2);
+
+ uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]);
+ CHECK_EQ(Code, 1000);
+
+ std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2);
+ CHECK_EQ(Reason, "normal closure");
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Integration tests: WebSocket over ASIO
+//
+
+namespace {
+
+ /**
+ * Helper: Build a masked client-to-server frame per RFC 6455
+ *
+ * Uses a fixed mask key so test frames are deterministic; length is
+ * encoded as 7-bit, 16-bit, or 64-bit depending on payload size.
+ */
+ std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+ {
+ std::vector<uint8_t> Frame;
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length with mask bit set
+ if (Payload.size() < 126)
+ {
+ Frame.push_back(0x80 | static_cast<uint8_t>(Payload.size()));
+ }
+ else if (Payload.size() <= 0xFFFF)
+ {
+ // len=126: 16-bit big-endian extended length follows
+ Frame.push_back(0x80 | 126);
+ Frame.push_back(static_cast<uint8_t>((Payload.size() >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(Payload.size() & 0xFF));
+ }
+ else
+ {
+ // len=127: 64-bit big-endian extended length follows
+ Frame.push_back(0x80 | 127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((Payload.size() >> (i * 8)) & 0xFF));
+ }
+ }
+
+ // Mask key (use a fixed key for deterministic tests)
+ uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78};
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+
+ // Masked payload
+ for (size_t i = 0; i < Payload.size(); ++i)
+ {
+ Frame.push_back(Payload[i] ^ MaskKey[i & 3]);
+ }
+
+ return Frame;
+ }
+
+ // Convenience wrapper: build a masked text frame from a string view.
+ std::vector<uint8_t> BuildMaskedTextFrame(std::string_view Text)
+ {
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ return BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+ }
+
+ // Build a masked close frame whose payload is just the big-endian
+ // 16-bit status code (no reason text).
+ std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code)
+ {
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ return BuildMaskedFrame(WebSocketOpcode::kClose, Payload);
+ }
+
+ /**
+ * Test service that implements IWebSocketHandler
+ *
+ * Counts open/message/close events, echoes text messages back to the
+ * sender, and tracks live connections so tests can broadcast.
+ */
+ struct WsTestService : public HttpService, public IWebSocketHandler
+ {
+ const char* BaseUri() const override { return "/wstest/"; }
+
+ void HandleRequest(HttpServerRequest& Request) override
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello from wstest");
+ }
+
+ // IWebSocketHandler
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
+ {
+ m_OpenCount.fetch_add(1);
+
+ m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); });
+ }
+
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override
+ {
+ m_MessageCount.fetch_add(1);
+
+ if (Msg.Opcode == WebSocketOpcode::kText)
+ {
+ std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size());
+ m_LastMessage = std::string(Text);
+
+ // Echo the message back
+ Conn.SendText(Text);
+ }
+ }
+
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) override
+ {
+ m_CloseCount.fetch_add(1);
+ m_LastCloseCode = Code;
+
+ // Remove the closed connection by identity (pointer compare)
+ m_ConnectionsLock.WithExclusiveLock([&] {
+ auto It = std::remove_if(m_Connections.begin(), m_Connections.end(), [&Conn](const Ref<WebSocketConnection>& C) {
+ return C.Get() == &Conn;
+ });
+ m_Connections.erase(It, m_Connections.end());
+ });
+ }
+
+ // Broadcast a text message to every connection that is still open
+ void SendToAll(std::string_view Text)
+ {
+ RwLock::SharedLockScope _(m_ConnectionsLock);
+ for (auto& Conn : m_Connections)
+ {
+ if (Conn->IsOpen())
+ {
+ Conn->SendText(Text);
+ }
+ }
+ }
+
+ std::atomic<int> m_OpenCount{0};
+ std::atomic<int> m_MessageCount{0};
+ std::atomic<int> m_CloseCount{0};
+ std::atomic<uint16_t> m_LastCloseCode{0};
+ std::string m_LastMessage; // NOTE(review): unsynchronized — assumes single test client at a time
+
+ RwLock m_ConnectionsLock;
+ std::vector<Ref<WebSocketConnection>> m_Connections;
+ };
+
+ /**
+ * Helper: Perform the WebSocket upgrade handshake on a raw TCP socket
+ *
+ * Sends the RFC 6455 sample nonce as the Sec-WebSocket-Key and checks
+ * only for a "101" substring in the status line region.
+ *
+ * Returns true on success (101 response), false otherwise.
+ */
+ bool DoWebSocketHandshake(asio::ip::tcp::socket& Sock, std::string_view Path, int Port)
+ {
+ // Send HTTP upgrade request
+ ExtendableStringBuilder<512> Request;
+ Request << "GET " << Path << " HTTP/1.1\r\n"
+ << "Host: 127.0.0.1:" << Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ << "Sec-WebSocket-Version: 13\r\n"
+ << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+
+ asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size()));
+
+ // Read the response (look for "101")
+ asio::streambuf ResponseBuf;
+ asio::read_until(Sock, ResponseBuf, "\r\n\r\n");
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+
+ return Response.find("101") != std::string::npos;
+ }
+
+ /**
+ * Helper: Read a single server-to-client frame from a socket
+ *
+ * Uses a background thread with a synchronous ASIO read and a timeout.
+ * Returns a default (invalid) WsFrameParseResult on timeout or error.
+ */
+ WsFrameParseResult ReadOneFrame(asio::ip::tcp::socket& Sock, int TimeoutMs = 5000)
+ {
+ std::vector<uint8_t> Buffer;
+ WsFrameParseResult Result;
+ std::atomic<bool> Done{false};
+
+ // Reader accumulates bytes until a complete frame parses or the read fails
+ std::thread Reader([&] {
+ while (!Done.load())
+ {
+ uint8_t Tmp[4096];
+ asio::error_code Ec;
+ size_t BytesRead = Sock.read_some(asio::buffer(Tmp), Ec);
+ if (Ec || BytesRead == 0)
+ {
+ // NOTE(review): on early socket error Done stays false, so the
+ // waiter below still spins until the full timeout elapses.
+ break;
+ }
+
+ Buffer.insert(Buffer.end(), Tmp, Tmp + BytesRead);
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size());
+ if (Frame.IsValid)
+ {
+ Result = std::move(Frame);
+ Done.store(true);
+ return;
+ }
+ }
+ });
+
+ // Poll for completion up to the deadline
+ auto Deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(TimeoutMs);
+ while (!Done.load() && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ if (!Done.load())
+ {
+ // Timeout — cancel the read
+ asio::error_code Ec;
+ Sock.cancel(Ec);
+ }
+
+ if (Reader.joinable())
+ {
+ Reader.join();
+ }
+
+ return Result;
+ }
+
+} // anonymous namespace
+
+// End-to-end tests: a real ASIO HTTP server with WsTestService registered,
+// driven by a raw TCP client that speaks the WebSocket wire protocol.
+TEST_CASE("websocket.integration")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
+
+ // Port holds the actually-bound port (presumably Initialize falls back
+ // from 7575 if taken — confirm against HttpServer::Initialize)
+ int Port = Server->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ // Ensure the server is stopped and joined regardless of subcase outcome
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ // Give server a moment to start accepting
+ Sleep(100);
+
+ SUBCASE("handshake succeeds with 101")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ CHECK(Ok);
+
+ Sleep(50);
+ CHECK_EQ(TestService.m_OpenCount.load(), 1);
+
+ Sock.close();
+ }
+
+ SUBCASE("normal HTTP still works alongside WebSocket service")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ // Send a normal HTTP GET (not upgrade)
+ std::string HttpReq = fmt::format(
+ "GET /wstest/hello HTTP/1.1\r\n"
+ "Host: 127.0.0.1:{}\r\n"
+ "Connection: close\r\n"
+ "\r\n",
+ Port);
+
+ asio::write(Sock, asio::buffer(HttpReq));
+
+ asio::streambuf ResponseBuf;
+ asio::error_code Ec;
+ asio::read(Sock, ResponseBuf, asio::transfer_at_least(1), Ec);
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+ CHECK(Response.find("200") != std::string::npos);
+ }
+
+ SUBCASE("echo message roundtrip")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send a text message (masked, as client)
+ std::vector<uint8_t> Frame = BuildMaskedTextFrame("ping test");
+ asio::write(Sock, asio::buffer(Frame));
+
+ // Read the echo reply
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
+ std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(ReplyText, "ping test"sv);
+ CHECK_EQ(TestService.m_MessageCount.load(), 1);
+ CHECK_EQ(TestService.m_LastMessage, "ping test");
+
+ Sock.close();
+ }
+
+ SUBCASE("server push to client")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Server pushes a message
+ TestService.SendToAll("server says hello");
+
+ WsFrameParseResult Frame = ReadOneFrame(Sock);
+ REQUIRE(Frame.IsValid);
+ CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText);
+ std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size());
+ CHECK_EQ(Text, "server says hello"sv);
+
+ Sock.close();
+ }
+
+ SUBCASE("client close handshake")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send close frame
+ std::vector<uint8_t> CloseFrame = BuildMaskedCloseFrame(1000);
+ asio::write(Sock, asio::buffer(CloseFrame));
+
+ // Server should echo close back
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose);
+
+ Sleep(50);
+ CHECK_EQ(TestService.m_CloseCount.load(), 1);
+ CHECK_EQ(TestService.m_LastCloseCode.load(), 1000);
+
+ Sock.close();
+ }
+
+ SUBCASE("multiple concurrent connections")
+ {
+ constexpr int NumClients = 5;
+
+ asio::io_context IoCtx;
+ std::vector<asio::ip::tcp::socket> Sockets;
+
+ for (int i = 0; i < NumClients; ++i)
+ {
+ Sockets.emplace_back(IoCtx);
+ Sockets.back().connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sockets.back(), "/wstest/ws", Port);
+ REQUIRE(Ok);
+ }
+
+ Sleep(100);
+ CHECK_EQ(TestService.m_OpenCount.load(), NumClients);
+
+ // Broadcast from server
+ TestService.SendToAll("broadcast");
+
+ // Each client should receive the message
+ for (int i = 0; i < NumClients; ++i)
+ {
+ WsFrameParseResult Frame = ReadOneFrame(Sockets[i]);
+ REQUIRE(Frame.IsValid);
+ CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText);
+ std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size());
+ CHECK_EQ(Text, "broadcast"sv);
+ }
+
+ // Close all
+ for (auto& S : Sockets)
+ {
+ S.close();
+ }
+ }
+
+ SUBCASE("service without IWebSocketHandler rejects upgrade")
+ {
+ // Register a plain HTTP service (no WebSocket)
+ struct PlainService : public HttpService
+ {
+ const char* BaseUri() const override { return "/plain/"; }
+ void HandleRequest(HttpServerRequest& Request) override { Request.WriteResponse(HttpResponseCode::OK); }
+ };
+
+ PlainService Plain;
+ Server->RegisterService(Plain);
+
+ Sleep(50);
+
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ // Attempt WebSocket upgrade on the plain service
+ ExtendableStringBuilder<512> Request;
+ Request << "GET /plain/ws HTTP/1.1\r\n"
+ << "Host: 127.0.0.1:" << Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ << "Sec-WebSocket-Version: 13\r\n"
+ << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+ asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size()));
+
+ asio::streambuf ResponseBuf;
+ asio::read_until(Sock, ResponseBuf, "\r\n\r\n");
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+
+ // Should NOT get 101 — should fall through to normal request handling
+ CHECK(Response.find("101") == std::string::npos);
+
+ Sock.close();
+ }
+
+ SUBCASE("ping/pong auto-response")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send a ping frame with payload "test"
+ std::string_view PingPayload = "test";
+ std::span<const uint8_t> PingData(reinterpret_cast<const uint8_t*>(PingPayload.data()), PingPayload.size());
+ std::vector<uint8_t> PingFrame = BuildMaskedFrame(WebSocketOpcode::kPing, PingData);
+ asio::write(Sock, asio::buffer(PingFrame));
+
+ // Should receive a pong with the same payload
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kPong);
+ CHECK_EQ(Reply.Payload.size(), 4u);
+ std::string_view PongText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(PongText, "test"sv);
+
+ Sock.close();
+ }
+
+ SUBCASE("multiple messages in sequence")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ for (int i = 0; i < 10; ++i)
+ {
+ std::string Msg = fmt::format("message {}", i);
+ std::vector<uint8_t> Frame = BuildMaskedTextFrame(Msg);
+ asio::write(Sock, asio::buffer(Frame));
+
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
+ std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(ReplyText, Msg);
+ }
+
+ CHECK_EQ(TestService.m_MessageCount.load(), 10);
+
+ Sock.close();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Integration tests: HttpWsClient
+//
+
+namespace {
+
+ // Client-side handler counterpart to WsTestService: records open/message/
+ // close events and the last text payload for assertions.
+ struct TestWsClientHandler : public IWsClientHandler
+ {
+ void OnWsOpen() override { m_OpenCount.fetch_add(1); }
+
+ void OnWsMessage(const WebSocketMessage& Msg) override
+ {
+ if (Msg.Opcode == WebSocketOpcode::kText)
+ {
+ std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size());
+ m_LastMessage = std::string(Text);
+ }
+ m_MessageCount.fetch_add(1);
+ }
+
+ void OnWsClose(uint16_t Code, [[maybe_unused]] std::string_view Reason) override
+ {
+ m_CloseCount.fetch_add(1);
+ m_LastCloseCode = Code;
+ }
+
+ std::atomic<int> m_OpenCount{0};
+ std::atomic<int> m_MessageCount{0};
+ std::atomic<int> m_CloseCount{0};
+ std::atomic<uint16_t> m_LastCloseCode{0};
+ std::string m_LastMessage; // NOTE(review): unsynchronized — read only after waiting on the counters
+ };
+
+} // anonymous namespace
+
+// Tests the HttpWsClient wrapper against the same ASIO server + echo service.
+TEST_CASE("websocket.client")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = Server->Initialize(7576, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ // Ensure the server is stopped and joined regardless of subcase outcome
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ Sleep(100);
+
+ SUBCASE("connect, echo, close")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port);
+
+ HttpWsClient Client(Url, Handler);
+ Client.Connect();
+
+ // Wait for OnWsOpen
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+ CHECK(Client.IsOpen());
+
+ // Send text, expect echo
+ Client.SendText("hello from client");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ CHECK_EQ(Handler.m_MessageCount.load(), 1);
+ CHECK_EQ(Handler.m_LastMessage, "hello from client");
+
+ // Close
+ Client.Close(1000, "done");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ // The server echoes the close frame, which triggers OnWsClose on the client side
+ // with the server's close code. Allow the connection to settle.
+ Sleep(50);
+ CHECK_FALSE(Client.IsOpen());
+ }
+
+ SUBCASE("connect to bad port")
+ {
+ TestWsClientHandler Handler;
+ // Port 1 is assumed unbound, so the TCP connect should fail and the
+ // client should report an abnormal close (1006)
+ std::string Url = "ws://127.0.0.1:1/wstest/ws";
+
+ HttpWsClient Client(Url, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)});
+ Client.Connect();
+
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ CHECK_EQ(Handler.m_CloseCount.load(), 1);
+ CHECK_EQ(Handler.m_LastCloseCode.load(), 1006);
+ CHECK_EQ(Handler.m_OpenCount.load(), 0);
+ }
+
+ SUBCASE("server-initiated close")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port);
+
+ HttpWsClient Client(Url, Handler);
+ Client.Connect();
+
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+
+ // Copy connections then close them outside the lock to avoid deadlocking
+ // with OnWebSocketClose which acquires an exclusive lock
+ std::vector<Ref<WebSocketConnection>> Conns;
+ TestService.m_ConnectionsLock.WithSharedLock([&] { Conns = TestService.m_Connections; });
+ for (auto& Conn : Conns)
+ {
+ Conn->Close(1001, "going away");
+ }
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ CHECK_EQ(Handler.m_CloseCount.load(), 1);
+ CHECK_EQ(Handler.m_LastCloseCode.load(), 1001);
+ CHECK_FALSE(Client.IsOpen());
+ }
+}
+
+TEST_SUITE_END();
+
+// Anchor symbol referenced from zenhttp.cpp's forcelink list so this
+// translation unit (and its tests) is linked into test binaries.
+void
+websocket_forcelink()
+{
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_TESTS
diff --git a/src/zenhttp/transports/dlltransport.cpp b/src/zenhttp/transports/dlltransport.cpp
index 9135d5425..489324aba 100644
--- a/src/zenhttp/transports/dlltransport.cpp
+++ b/src/zenhttp/transports/dlltransport.cpp
@@ -72,20 +72,36 @@ DllTransportLogger::DllTransportLogger(std::string_view PluginName) : m_PluginNa
void
DllTransportLogger::LogMessage(LogLevel PluginLogLevel, const char* Message)
{
- logging::level::LogLevel Level;
- // clang-format off
switch (PluginLogLevel)
{
- case LogLevel::Trace: Level = logging::level::Trace; break;
- case LogLevel::Debug: Level = logging::level::Debug; break;
- case LogLevel::Info: Level = logging::level::Info; break;
- case LogLevel::Warn: Level = logging::level::Warn; break;
- case LogLevel::Err: Level = logging::level::Err; break;
- case LogLevel::Critical: Level = logging::level::Critical; break;
- default: Level = logging::level::Off; break;
+ case LogLevel::Trace:
+ ZEN_TRACE("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Debug:
+ ZEN_DEBUG("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Info:
+ ZEN_INFO("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Warn:
+ ZEN_WARN("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Err:
+ ZEN_ERROR("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Critical:
+ ZEN_CRITICAL("[{}] {}", m_PluginName, Message);
+ return;
+
+ default:
+ ZEN_UNUSED(Message);
+ break;
}
- // clang-format on
- ZEN_LOG(Log(), Level, "[{}] {}", m_PluginName, Message)
}
uint32_t
diff --git a/src/zenhttp/transports/winsocktransport.cpp b/src/zenhttp/transports/winsocktransport.cpp
index c06a50c95..0217ed44e 100644
--- a/src/zenhttp/transports/winsocktransport.cpp
+++ b/src/zenhttp/transports/winsocktransport.cpp
@@ -322,7 +322,7 @@ SocketTransportPluginImpl::Initialize(TransportServer* ServerInterface)
else
{
}
- } while (!IsApplicationExitRequested() && m_KeepRunning.test());
+ } while (m_KeepRunning.test());
ZEN_INFO("HTTP plugin server accept thread exit");
});
diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua
index 78876d21b..e8f87b668 100644
--- a/src/zenhttp/xmake.lua
+++ b/src/zenhttp/xmake.lua
@@ -6,6 +6,7 @@ target('zenhttp')
add_headerfiles("**.h")
add_files("**.cpp")
add_files("servers/httpsys.cpp", {unity_ignored=true})
+ add_files("servers/wshttpsys.cpp", {unity_ignored=true})
add_includedirs("include", {public=true})
add_deps("zencore", "zentelemetry", "transport-sdk", "asio", "cpr")
add_packages("http_parser", "json11")
diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp
index a2679f92e..3ac8eea8d 100644
--- a/src/zenhttp/zenhttp.cpp
+++ b/src/zenhttp/zenhttp.cpp
@@ -7,6 +7,7 @@
# include <zenhttp/httpclient.h>
# include <zenhttp/httpserver.h>
# include <zenhttp/packageformat.h>
+# include <zenhttp/security/passwordsecurity.h>
namespace zen {
@@ -15,7 +16,10 @@ zenhttp_forcelinktests()
{
http_forcelink();
httpclient_forcelink();
+ httpclient_test_forcelink();
forcelink_packageformat();
+ passwordsecurity_forcelink();
+ websocket_forcelink();
}
} // namespace zen
diff --git a/src/zennet-test/zennet-test.cpp b/src/zennet-test/zennet-test.cpp
index bc3b8e8e9..1283eb820 100644
--- a/src/zennet-test/zennet-test.cpp
+++ b/src/zennet-test/zennet-test.cpp
@@ -1,45 +1,15 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zencore/filesystem.h>
-#include <zencore/logging.h>
-#include <zencore/trace.h>
+#include <zencore/testing.h>
#include <zennet/zennet.h>
#include <zencore/memory/newdelete.h>
-#if ZEN_WITH_TESTS
-# define ZEN_TEST_WITH_RUNNER 1
-# include <zencore/testing.h>
-# include <zencore/process.h>
-#endif
-
int
main([[maybe_unused]] int argc, [[maybe_unused]] char** argv)
{
-#if ZEN_PLATFORM_WINDOWS
- setlocale(LC_ALL, "en_us.UTF8");
-#endif // ZEN_PLATFORM_WINDOWS
-
#if ZEN_WITH_TESTS
- zen::zennet_forcelinktests();
-
-# if ZEN_PLATFORM_LINUX
- zen::IgnoreChildSignals();
-# endif
-
-# if ZEN_WITH_TRACE
- zen::TraceInit("zennet-test");
- zen::TraceOptions TraceCommandlineOptions;
- if (GetTraceOptionsFromCommandline(TraceCommandlineOptions))
- {
- TraceConfigure(TraceCommandlineOptions);
- }
-# endif // ZEN_WITH_TRACE
-
- zen::logging::InitializeLogging();
- zen::MaximizeOpenFileCount();
-
- return ZEN_RUN_TESTS(argc, argv);
+ return zen::testing::RunTestMain(argc, argv, "zennet-test", zen::zennet_forcelinktests);
#else
return 0;
#endif
diff --git a/src/zennet/beacon.cpp b/src/zennet/beacon.cpp
new file mode 100644
index 000000000..394a4afbb
--- /dev/null
+++ b/src/zennet/beacon.cpp
@@ -0,0 +1,170 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zennet/beacon.h>
+
+#include <zencore/basicfile.h>
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinaryfile.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/session.h>
+#include <zencore/uid.h>
+
+#include <fmt/format.h>
+#include <asio.hpp>
+#include <map>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+struct FsBeacon::Impl
+{
+ Impl(std::filesystem::path ShareRoot);
+ ~Impl();
+
+ void EnsureValid();
+
+ void AddGroup(std::string_view GroupId, CbObject Metadata);
+ void ScanGroup(std::string_view GroupId, std::vector<Oid>& OutSessions);
+ void ReadMetadata(std::string_view GroupId, const std::vector<Oid>& InSessions, std::vector<CbObject>& OutMetadata);
+
+private:
+ std::filesystem::path m_ShareRoot;
+ zen::Oid m_SessionId;
+
+ struct GroupData
+ {
+ CbObject Metadata;
+ BasicFile LockFile;
+ };
+
+ std::map<std::string, GroupData> m_Registration;
+
+ std::filesystem::path GetSessionMarkerPath(std::string_view GroupId, const Oid& SessionId)
+ {
+ Oid::String_t SessionIdString;
+ SessionId.ToString(SessionIdString);
+
+ return m_ShareRoot / GroupId / SessionIdString;
+ }
+};
+
+FsBeacon::Impl::Impl(std::filesystem::path ShareRoot) : m_ShareRoot(ShareRoot), m_SessionId(GetSessionId())
+{
+}
+
+FsBeacon::Impl::~Impl()
+{
+}
+
+void
+FsBeacon::Impl::EnsureValid()
+{
+}
+
+void
+FsBeacon::Impl::AddGroup(std::string_view GroupId, CbObject Metadata)
+{
+ zen::CreateDirectories(m_ShareRoot / GroupId);
+ std::filesystem::path MarkerFile = GetSessionMarkerPath(GroupId, m_SessionId);
+
+ GroupData& Group = m_Registration[std::string(GroupId)];
+
+ Group.Metadata = Metadata;
+
+ std::error_code Ec;
+ Group.LockFile.Open(MarkerFile,
+ BasicFile::Mode::kTruncate | BasicFile::Mode::kPreventDelete |
+ BasicFile::Mode::kPreventWrite /* | BasicFile::Mode::kDeleteOnClose */,
+ Ec);
+
+ if (Ec)
+ {
+ throw std::system_error(Ec, fmt::format("failed to open beacon marker file '{}' for write", MarkerFile));
+ }
+
+ Group.LockFile.WriteAll(Metadata.GetBuffer().AsIoBuffer(), Ec);
+
+ if (Ec)
+ {
+ throw std::system_error(Ec, fmt::format("failed to write to beacon marker file '{}'", MarkerFile));
+ }
+
+ Group.LockFile.Flush();
+}
+
+void
+FsBeacon::Impl::ScanGroup(std::string_view GroupId, std::vector<Oid>& OutSessions)
+{
+ DirectoryContent Dc;
+ zen::GetDirectoryContent(m_ShareRoot / GroupId, zen::DirectoryContentFlags::IncludeFiles, /* out */ Dc);
+
+ for (const std::filesystem::path& FilePath : Dc.Files)
+ {
+ std::filesystem::path File = FilePath.filename();
+
+ std::error_code Ec;
+ if (std::filesystem::remove(FilePath, Ec) == false)
+ {
+ auto FileString = File.generic_string();
+
+ if (FileString.length() != Oid::StringLength)
+ continue;
+
+ if (const Oid SessionId = Oid::FromHexString(FileString))
+ {
+            if (std::filesystem::file_size(FilePath, Ec) > 0)
+ {
+ OutSessions.push_back(SessionId);
+ }
+ }
+ }
+ }
+}
+
+void
+FsBeacon::Impl::ReadMetadata(std::string_view GroupId, const std::vector<Oid>& InSessions, std::vector<CbObject>& OutMetadata)
+{
+ for (const Oid& SessionId : InSessions)
+ {
+ const std::filesystem::path MarkerFile = GetSessionMarkerPath(GroupId, SessionId);
+
+ if (CbObject Metadata = LoadCompactBinaryObject(MarkerFile).Object)
+ {
+ OutMetadata.push_back(std::move(Metadata));
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+FsBeacon::FsBeacon(std::filesystem::path ShareRoot) : m_Impl(std::make_unique<Impl>(ShareRoot))
+{
+}
+
+FsBeacon::~FsBeacon()
+{
+}
+
+void
+FsBeacon::AddGroup(std::string_view GroupId, CbObject Metadata)
+{
+ m_Impl->AddGroup(GroupId, Metadata);
+}
+
+void
+FsBeacon::ScanGroup(std::string_view GroupId, std::vector<Oid>& OutSessions)
+{
+ m_Impl->ScanGroup(GroupId, OutSessions);
+}
+
+void
+FsBeacon::ReadMetadata(std::string_view GroupId, const std::vector<Oid>& InSessions, std::vector<CbObject>& OutMetadata)
+{
+ m_Impl->ReadMetadata(GroupId, InSessions, OutMetadata);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+} // namespace zen
diff --git a/src/zennet/include/zennet/beacon.h b/src/zennet/include/zennet/beacon.h
new file mode 100644
index 000000000..a8d4805cb
--- /dev/null
+++ b/src/zennet/include/zennet/beacon.h
@@ -0,0 +1,38 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zennet/zennet.h>
+
+#include <zencore/uid.h>
+
+#include <filesystem>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace zen {
+
+class CbObject;
+
+/** File-system based peer discovery
+
+ Intended to be used with an SMB file share as the root.
+ */
+
+class FsBeacon
+{
+public:
+ FsBeacon(std::filesystem::path ShareRoot);
+ ~FsBeacon();
+
+ void AddGroup(std::string_view GroupId, CbObject Metadata);
+ void ScanGroup(std::string_view GroupId, std::vector<Oid>& OutSessions);
+ void ReadMetadata(std::string_view GroupId, const std::vector<Oid>& InSessions, std::vector<CbObject>& OutMetadata);
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+} // namespace zen
diff --git a/src/zennet/include/zennet/statsdclient.h b/src/zennet/include/zennet/statsdclient.h
index c378e49ce..7688c132c 100644
--- a/src/zennet/include/zennet/statsdclient.h
+++ b/src/zennet/include/zennet/statsdclient.h
@@ -8,6 +8,8 @@
#include <memory>
#include <string_view>
+#undef SendMessage
+
namespace zen {
class StatsTransportBase
diff --git a/src/zennet/statsdclient.cpp b/src/zennet/statsdclient.cpp
index fe5ca4dda..8afa2e835 100644
--- a/src/zennet/statsdclient.cpp
+++ b/src/zennet/statsdclient.cpp
@@ -12,6 +12,7 @@
ZEN_THIRD_PARTY_INCLUDES_START
#include <zencore/windows.h>
#include <asio.hpp>
+#undef SendMessage
ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
@@ -379,6 +380,8 @@ statsd_forcelink()
{
}
+TEST_SUITE_BEGIN("net.statsdclient");
+
TEST_CASE("zennet.statsd.emit")
{
// auto Client = CreateStatsDaemonClient("localhost", 8125);
@@ -458,6 +461,8 @@ TEST_CASE("zennet.statsd.batch")
}
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zennomad/include/zennomad/nomadclient.h b/src/zennomad/include/zennomad/nomadclient.h
new file mode 100644
index 000000000..0a3411ace
--- /dev/null
+++ b/src/zennomad/include/zennomad/nomadclient.h
@@ -0,0 +1,77 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zennomad/nomadconfig.h>
+
+#include <zencore/logbase.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace zen {
+class HttpClient;
+}
+
+namespace zen::nomad {
+
+/** Summary of a Nomad job returned by the API. */
+struct NomadJobInfo
+{
+ std::string Id;
+ std::string Status; ///< "pending", "running", "dead"
+ std::string StatusDescription;
+};
+
+/** Summary of a Nomad allocation returned by the API. */
+struct NomadAllocInfo
+{
+ std::string Id;
+ std::string ClientStatus; ///< "pending", "running", "complete", "failed"
+ std::string TaskState; ///< State of the task within the allocation
+};
+
+/** HTTP client for the Nomad REST API (v1).
+ *
+ * Handles job submission, status polling, and job termination.
+ * All calls are synchronous. Thread safety: individual methods are
+ * not thread-safe; callers must synchronize access.
+ */
+class NomadClient
+{
+public:
+ explicit NomadClient(const NomadConfig& Config);
+ ~NomadClient();
+
+ NomadClient(const NomadClient&) = delete;
+ NomadClient& operator=(const NomadClient&) = delete;
+
+ /** Initialize the underlying HTTP client. Must be called before other methods. */
+ bool Initialize();
+
+ /** Build the Nomad job registration JSON for the given job ID and orchestrator endpoint.
+ * The JSON structure varies based on the configured driver and distribution mode. */
+ std::string BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const;
+
+ /** Submit a job via PUT /v1/jobs. On success, populates OutJob with the job info. */
+ bool SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob);
+
+ /** Get the status of a job via GET /v1/job/{jobId}. */
+ bool GetJobStatus(const std::string& JobId, NomadJobInfo& OutJob);
+
+ /** Get allocations for a job via GET /v1/job/{jobId}/allocations. */
+ bool GetAllocations(const std::string& JobId, std::vector<NomadAllocInfo>& OutAllocs);
+
+ /** Stop a job via DELETE /v1/job/{jobId}. */
+ bool StopJob(const std::string& JobId);
+
+ LoggerRef Log() { return m_Log; }
+
+private:
+ NomadConfig m_Config;
+ std::unique_ptr<zen::HttpClient> m_Http;
+ LoggerRef m_Log;
+};
+
+} // namespace zen::nomad
diff --git a/src/zennomad/include/zennomad/nomadconfig.h b/src/zennomad/include/zennomad/nomadconfig.h
new file mode 100644
index 000000000..92d2bbaca
--- /dev/null
+++ b/src/zennomad/include/zennomad/nomadconfig.h
@@ -0,0 +1,65 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zennomad/zennomad.h>
+
+#include <string>
+
+namespace zen::nomad {
+
+/** Nomad task driver type. */
+enum class Driver
+{
+ RawExec, ///< Use Nomad raw_exec driver (direct process execution)
+ Docker, ///< Use Nomad Docker driver
+};
+
+/** How the zenserver binary is made available on Nomad clients. */
+enum class BinaryDistribution
+{
+ PreDeployed, ///< Binary is already present on Nomad client nodes
+ Artifact, ///< Download binary via Nomad artifact stanza
+};
+
+/** Configuration for Nomad worker provisioning.
+ *
+ * Specifies the Nomad server URL, authentication, resource limits, and
+ * job configuration. Used by NomadClient and NomadProvisioner.
+ */
+struct NomadConfig
+{
+ bool Enabled = false; ///< Whether Nomad provisioning is active
+ std::string ServerUrl; ///< Nomad HTTP API URL (e.g. "http://localhost:4646")
+ std::string AclToken; ///< Nomad ACL token (sent as X-Nomad-Token header)
+ std::string Datacenter = "dc1"; ///< Target datacenter
+ std::string Namespace = "default"; ///< Nomad namespace
+ std::string Region; ///< Nomad region (empty = server default)
+
+ Driver TaskDriver = Driver::RawExec; ///< Task driver for job execution
+ BinaryDistribution BinDistribution = BinaryDistribution::PreDeployed; ///< How to distribute the zenserver binary
+
+ std::string BinaryPath; ///< Path to zenserver on Nomad clients (PreDeployed mode)
+ std::string ArtifactSource; ///< URL to download zenserver binary (Artifact mode)
+ std::string DockerImage; ///< Docker image name (Docker driver mode)
+
+ int MaxJobs = 64; ///< Maximum concurrent Nomad jobs
+ int CpuMhz = 1000; ///< CPU MHz allocated per task
+ int MemoryMb = 2048; ///< Memory MB allocated per task
+ int CoresPerJob = 32; ///< Estimated cores per job (for scaling calculations)
+ int MaxCores = 2048; ///< Maximum total cores to provision
+
+ std::string JobPrefix = "zenserver-worker"; ///< Prefix for generated Nomad job IDs
+
+ /** Validate the configuration. Returns false if required fields are missing
+ * or incompatible options are set. */
+ bool Validate() const;
+};
+
+const char* ToString(Driver D);
+const char* ToString(BinaryDistribution Dist);
+
+bool FromString(Driver& OutDriver, std::string_view Str);
+bool FromString(BinaryDistribution& OutDist, std::string_view Str);
+
+} // namespace zen::nomad
diff --git a/src/zennomad/include/zennomad/nomadprocess.h b/src/zennomad/include/zennomad/nomadprocess.h
new file mode 100644
index 000000000..a66c2ce41
--- /dev/null
+++ b/src/zennomad/include/zennomad/nomadprocess.h
@@ -0,0 +1,78 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpclient.h>
+
+#include <memory>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace zen::nomad {
+
+struct NomadJobInfo;
+struct NomadAllocInfo;
+
+/** Manages a Nomad agent process running in dev mode for testing.
+ *
+ * Spawns `nomad agent -dev` and polls the HTTP API until the agent
+ * is ready. On destruction or via StopNomadAgent(), the agent
+ * process is killed.
+ */
+class NomadProcess
+{
+public:
+ NomadProcess();
+ ~NomadProcess();
+
+ NomadProcess(const NomadProcess&) = delete;
+ NomadProcess& operator=(const NomadProcess&) = delete;
+
+    /** Spawn a Nomad dev agent and block until the leader endpoint responds (30 s timeout). */
+ void SpawnNomadAgent();
+
+ /** Kill the Nomad agent process. */
+ void StopNomadAgent();
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+/** Lightweight HTTP wrapper around the Nomad v1 REST API for use in tests.
+ *
+ * Unlike the production NomadClient (which requires a NomadConfig and
+ * supports all driver/distribution modes), this client exposes a simpler
+ * interface geared towards test scenarios.
+ */
+class NomadTestClient
+{
+public:
+ explicit NomadTestClient(std::string_view BaseUri);
+ ~NomadTestClient();
+
+ NomadTestClient(const NomadTestClient&) = delete;
+ NomadTestClient& operator=(const NomadTestClient&) = delete;
+
+ /** Submit a raw_exec batch job.
+ * Returns the parsed job info on success; Id will be empty on failure. */
+ NomadJobInfo SubmitJob(std::string_view JobId, std::string_view Command, const std::vector<std::string>& Args);
+
+ /** Query the status of an existing job. */
+ NomadJobInfo GetJobStatus(std::string_view JobId);
+
+ /** Stop (deregister) a running job. */
+ void StopJob(std::string_view JobId);
+
+ /** Get allocations for a job. */
+ std::vector<NomadAllocInfo> GetAllocations(std::string_view JobId);
+
+ /** List all jobs, optionally filtered by prefix. */
+ std::vector<NomadJobInfo> ListJobs(std::string_view Prefix = "");
+
+private:
+ HttpClient m_HttpClient;
+};
+
+} // namespace zen::nomad
diff --git a/src/zennomad/include/zennomad/nomadprovisioner.h b/src/zennomad/include/zennomad/nomadprovisioner.h
new file mode 100644
index 000000000..750693b3f
--- /dev/null
+++ b/src/zennomad/include/zennomad/nomadprovisioner.h
@@ -0,0 +1,107 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zennomad/nomadconfig.h>
+
+#include <zencore/logbase.h>
+
+#include <atomic>
+#include <condition_variable>
+#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <vector>
+
+namespace zen::nomad {
+
+class NomadClient;
+
+/** Snapshot of the current Nomad provisioning state, returned by NomadProvisioner::GetStats(). */
+struct NomadProvisioningStats
+{
+ uint32_t TargetCoreCount = 0; ///< Requested number of cores (clamped to MaxCores)
+ uint32_t EstimatedCoreCount = 0; ///< Cores expected from submitted jobs
+ uint32_t ActiveJobCount = 0; ///< Number of currently tracked Nomad jobs
+ uint32_t RunningJobCount = 0; ///< Number of jobs in "running" status
+};
+
+/** Job lifecycle manager for Nomad worker provisioning.
+ *
+ * Provisions remote compute workers by submitting batch jobs to a Nomad
+ * cluster via the REST API. Each job runs zenserver in compute mode, which
+ * announces itself back to the orchestrator.
+ *
+ * Uses a single management thread that periodically:
+ * 1. Submits new jobs when estimated cores < target cores
+ * 2. Polls existing jobs for status changes
+ * 3. Cleans up dead/failed jobs and adjusts counters
+ *
+ * Thread safety: SetTargetCoreCount and GetStats may be called from any thread.
+ */
+class NomadProvisioner
+{
+public:
+ /** Construct a provisioner.
+ * @param Config Nomad connection and job configuration.
+ * @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. */
+ NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint);
+
+ /** Signals the management thread to exit and stops all tracked jobs. */
+ ~NomadProvisioner();
+
+ NomadProvisioner(const NomadProvisioner&) = delete;
+ NomadProvisioner& operator=(const NomadProvisioner&) = delete;
+
+ /** Set the target number of cores to provision.
+ * Clamped to NomadConfig::MaxCores. The management thread will
+ * submit new jobs to approach this target. */
+ void SetTargetCoreCount(uint32_t Count);
+
+ /** Return a snapshot of the current provisioning counters. */
+ NomadProvisioningStats GetStats() const;
+
+private:
+ LoggerRef Log() { return m_Log; }
+
+ struct TrackedJob
+ {
+ std::string JobId;
+ std::string Status; ///< "pending", "running", "dead"
+ int Cores = 0;
+ };
+
+ void ManagementThread();
+ void SubmitNewJobs();
+ void PollExistingJobs();
+ void CleanupDeadJobs();
+ void StopAllJobs();
+
+ std::string GenerateJobId();
+
+ NomadConfig m_Config;
+ std::string m_OrchestratorEndpoint;
+
+ std::unique_ptr<NomadClient> m_Client;
+
+ mutable std::mutex m_JobsLock;
+ std::vector<TrackedJob> m_Jobs;
+ std::atomic<uint32_t> m_JobIndex{0};
+
+ std::atomic<uint32_t> m_TargetCoreCount{0};
+ std::atomic<uint32_t> m_EstimatedCoreCount{0};
+ std::atomic<uint32_t> m_RunningJobCount{0};
+
+ std::thread m_Thread;
+ std::mutex m_WakeMutex;
+ std::condition_variable m_WakeCV;
+ std::atomic<bool> m_ShouldExit{false};
+
+ uint32_t m_ProcessId = 0;
+
+ LoggerRef m_Log;
+};
+
+} // namespace zen::nomad
diff --git a/src/zennomad/include/zennomad/zennomad.h b/src/zennomad/include/zennomad/zennomad.h
new file mode 100644
index 000000000..09fb98dfe
--- /dev/null
+++ b/src/zennomad/include/zennomad/zennomad.h
@@ -0,0 +1,9 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#if !defined(ZEN_WITH_NOMAD)
+# define ZEN_WITH_NOMAD 1
+#endif
diff --git a/src/zennomad/nomadclient.cpp b/src/zennomad/nomadclient.cpp
new file mode 100644
index 000000000..9edcde125
--- /dev/null
+++ b/src/zennomad/nomadclient.cpp
@@ -0,0 +1,366 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/fmtutils.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/memoryview.h>
+#include <zencore/trace.h>
+#include <zenhttp/httpclient.h>
+#include <zennomad/nomadclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::nomad {
+
+namespace {
+
+ HttpClient::KeyValueMap MakeNomadHeaders(const NomadConfig& Config)
+ {
+ HttpClient::KeyValueMap Headers;
+ if (!Config.AclToken.empty())
+ {
+ Headers->emplace("X-Nomad-Token", Config.AclToken);
+ }
+ return Headers;
+ }
+
+} // namespace
+
+NomadClient::NomadClient(const NomadConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("nomad.client"))
+{
+}
+
+NomadClient::~NomadClient() = default;
+
+bool
+NomadClient::Initialize()
+{
+ ZEN_TRACE_CPU("NomadClient::Initialize");
+
+ HttpClientSettings Settings;
+ Settings.LogCategory = "nomad.http";
+ Settings.ConnectTimeout = std::chrono::milliseconds{10000};
+ Settings.Timeout = std::chrono::milliseconds{60000};
+ Settings.RetryCount = 1;
+
+ // Ensure the base URL ends with a slash so path concatenation works correctly
+ std::string BaseUrl = m_Config.ServerUrl;
+ if (!BaseUrl.empty() && BaseUrl.back() != '/')
+ {
+ BaseUrl += '/';
+ }
+
+ m_Http = std::make_unique<zen::HttpClient>(BaseUrl, Settings);
+
+ return true;
+}
+
+std::string
+NomadClient::BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const
+{
+ ZEN_TRACE_CPU("NomadClient::BuildJobJson");
+
+ // Build the task config based on driver and distribution mode
+ json11::Json::object TaskConfig;
+
+ if (m_Config.TaskDriver == Driver::RawExec)
+ {
+ std::string Command;
+ if (m_Config.BinDistribution == BinaryDistribution::PreDeployed)
+ {
+ Command = m_Config.BinaryPath;
+ }
+ else
+ {
+ // Artifact mode: binary is downloaded to local/zenserver
+ Command = "local/zenserver";
+ }
+
+ TaskConfig["command"] = Command;
+
+ json11::Json::array Args;
+ Args.push_back("compute");
+ Args.push_back("--http=asio");
+ if (!OrchestratorEndpoint.empty())
+ {
+ ExtendableStringBuilder<256> CoordArg;
+ CoordArg << "--coordinator-endpoint=" << OrchestratorEndpoint;
+ Args.push_back(std::string(CoordArg.ToView()));
+ }
+ {
+ ExtendableStringBuilder<128> IdArg;
+ IdArg << "--instance-id=nomad-" << JobId;
+ Args.push_back(std::string(IdArg.ToView()));
+ }
+ TaskConfig["args"] = Args;
+ }
+ else
+ {
+ // Docker driver
+ TaskConfig["image"] = m_Config.DockerImage;
+
+ json11::Json::array Args;
+ Args.push_back("compute");
+ Args.push_back("--http=asio");
+ if (!OrchestratorEndpoint.empty())
+ {
+ ExtendableStringBuilder<256> CoordArg;
+ CoordArg << "--coordinator-endpoint=" << OrchestratorEndpoint;
+ Args.push_back(std::string(CoordArg.ToView()));
+ }
+ {
+ ExtendableStringBuilder<128> IdArg;
+ IdArg << "--instance-id=nomad-" << JobId;
+ Args.push_back(std::string(IdArg.ToView()));
+ }
+ TaskConfig["args"] = Args;
+ }
+
+ // Build resource stanza
+ json11::Json::object Resources;
+ Resources["CPU"] = m_Config.CpuMhz;
+ Resources["MemoryMB"] = m_Config.MemoryMb;
+
+ // Build the task
+ json11::Json::object Task;
+ Task["Name"] = "zenserver";
+ Task["Driver"] = (m_Config.TaskDriver == Driver::RawExec) ? "raw_exec" : "docker";
+ Task["Config"] = TaskConfig;
+ Task["Resources"] = Resources;
+
+ // Add artifact stanza if using artifact distribution
+ if (m_Config.BinDistribution == BinaryDistribution::Artifact && !m_Config.ArtifactSource.empty())
+ {
+ json11::Json::object Artifact;
+ Artifact["GetterSource"] = m_Config.ArtifactSource;
+
+ json11::Json::array Artifacts;
+ Artifacts.push_back(Artifact);
+ Task["Artifacts"] = Artifacts;
+ }
+
+ json11::Json::array Tasks;
+ Tasks.push_back(Task);
+
+ // Build the task group
+ json11::Json::object Group;
+ Group["Name"] = "zenserver-group";
+ Group["Count"] = 1;
+ Group["Tasks"] = Tasks;
+
+ json11::Json::array Groups;
+ Groups.push_back(Group);
+
+ // Build datacenters array
+ json11::Json::array Datacenters;
+ Datacenters.push_back(m_Config.Datacenter);
+
+ // Build the job
+ json11::Json::object Job;
+ Job["ID"] = JobId;
+ Job["Name"] = JobId;
+ Job["Type"] = "batch";
+ Job["Datacenters"] = Datacenters;
+ Job["TaskGroups"] = Groups;
+
+ if (!m_Config.Namespace.empty() && m_Config.Namespace != "default")
+ {
+ Job["Namespace"] = m_Config.Namespace;
+ }
+
+ if (!m_Config.Region.empty())
+ {
+ Job["Region"] = m_Config.Region;
+ }
+
+ // Wrap in the registration envelope
+ json11::Json::object Root;
+ Root["Job"] = Job;
+
+ return json11::Json(Root).dump();
+}
+
+bool
+NomadClient::SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob)
+{
+ ZEN_TRACE_CPU("NomadClient::SubmitJob");
+
+ const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{JobJson.data(), JobJson.size()}, ZenContentType::kJSON);
+
+ const HttpClient::Response Response = m_Http->Put("v1/jobs", Payload, MakeNomadHeaders(m_Config));
+
+ if (Response.Error)
+ {
+ ZEN_WARN("Nomad job submit failed: {}", Response.Error->ErrorMessage);
+ return false;
+ }
+
+ const int StatusCode = static_cast<int>(Response.StatusCode);
+
+ if (!Response.IsSuccess())
+ {
+ ZEN_WARN("Nomad job submit failed with HTTP/{}", StatusCode);
+ return false;
+ }
+
+ const std::string Body(Response.AsText());
+ std::string Err;
+ const json11::Json Json = json11::Json::parse(Body, Err);
+
+ if (!Err.empty())
+ {
+ ZEN_WARN("invalid JSON response from Nomad job submit: {}", Err);
+ return false;
+ }
+
+    // The registration response carries only EvalID, never the job ID; retain the caller-provided OutJob.Id (caller must pre-set it), clearing it if no JobModifyIndex was returned
+    OutJob.Id = Json["JobModifyIndex"].is_number() ? OutJob.Id : "";
+ OutJob.Status = "pending";
+
+ ZEN_INFO("Nomad job submitted: eval_id={}", Json["EvalID"].string_value());
+
+ return true;
+}
+
+bool
+NomadClient::GetJobStatus(const std::string& JobId, NomadJobInfo& OutJob)
+{
+ ZEN_TRACE_CPU("NomadClient::GetJobStatus");
+
+ ExtendableStringBuilder<128> Path;
+ Path << "v1/job/" << JobId;
+
+ const HttpClient::Response Response = m_Http->Get(Path.ToView(), MakeNomadHeaders(m_Config));
+
+ if (Response.Error)
+ {
+ ZEN_WARN("Nomad job status query failed for '{}': {}", JobId, Response.Error->ErrorMessage);
+ return false;
+ }
+
+ const int StatusCode = static_cast<int>(Response.StatusCode);
+
+ if (StatusCode == 404)
+ {
+ ZEN_INFO("Nomad job '{}' not found", JobId);
+ OutJob.Status = "dead";
+ return true;
+ }
+
+ if (!Response.IsSuccess())
+ {
+ ZEN_WARN("Nomad job status query failed with HTTP/{}", StatusCode);
+ return false;
+ }
+
+ const std::string Body(Response.AsText());
+ std::string Err;
+ const json11::Json Json = json11::Json::parse(Body, Err);
+
+ if (!Err.empty())
+ {
+ ZEN_WARN("invalid JSON in Nomad job status response: {}", Err);
+ return false;
+ }
+
+ OutJob.Id = Json["ID"].string_value();
+ OutJob.Status = Json["Status"].string_value();
+ if (const json11::Json Desc = Json["StatusDescription"]; Desc.is_string())
+ {
+ OutJob.StatusDescription = Desc.string_value();
+ }
+
+ return true;
+}
+
+bool
+NomadClient::GetAllocations(const std::string& JobId, std::vector<NomadAllocInfo>& OutAllocs)
+{
+ ZEN_TRACE_CPU("NomadClient::GetAllocations");
+
+ ExtendableStringBuilder<128> Path;
+ Path << "v1/job/" << JobId << "/allocations";
+
+ const HttpClient::Response Response = m_Http->Get(Path.ToView(), MakeNomadHeaders(m_Config));
+
+ if (Response.Error)
+ {
+ ZEN_WARN("Nomad allocation query failed for '{}': {}", JobId, Response.Error->ErrorMessage);
+ return false;
+ }
+
+ if (!Response.IsSuccess())
+ {
+ ZEN_WARN("Nomad allocation query failed with HTTP/{}", static_cast<int>(Response.StatusCode));
+ return false;
+ }
+
+ const std::string Body(Response.AsText());
+ std::string Err;
+ const json11::Json Json = json11::Json::parse(Body, Err);
+
+ if (!Err.empty())
+ {
+ ZEN_WARN("invalid JSON in Nomad allocation response: {}", Err);
+ return false;
+ }
+
+ OutAllocs.clear();
+ if (!Json.is_array())
+ {
+ return true;
+ }
+
+ for (const json11::Json& AllocVal : Json.array_items())
+ {
+ NomadAllocInfo Alloc;
+ Alloc.Id = AllocVal["ID"].string_value();
+ Alloc.ClientStatus = AllocVal["ClientStatus"].string_value();
+
+ // Extract task state if available
+ if (const json11::Json TaskStates = AllocVal["TaskStates"]; TaskStates.is_object())
+ {
+ for (const auto& [TaskName, TaskState] : TaskStates.object_items())
+ {
+ if (TaskState["State"].is_string())
+ {
+ Alloc.TaskState = TaskState["State"].string_value();
+ }
+ }
+ }
+
+ OutAllocs.push_back(std::move(Alloc));
+ }
+
+ return true;
+}
+
+bool
+NomadClient::StopJob(const std::string& JobId)
+{
+ ZEN_TRACE_CPU("NomadClient::StopJob");
+
+ ExtendableStringBuilder<128> Path;
+ Path << "v1/job/" << JobId;
+
+ const HttpClient::Response Response = m_Http->Delete(Path.ToView(), MakeNomadHeaders(m_Config));
+
+ if (Response.Error)
+ {
+ ZEN_WARN("Nomad job stop failed for '{}': {}", JobId, Response.Error->ErrorMessage);
+ return false;
+ }
+
+ if (!Response.IsSuccess())
+ {
+ ZEN_WARN("Nomad job stop failed with HTTP/{}", static_cast<int>(Response.StatusCode));
+ return false;
+ }
+
+ ZEN_INFO("Nomad job '{}' stopped", JobId);
+ return true;
+}
+
+} // namespace zen::nomad
diff --git a/src/zennomad/nomadconfig.cpp b/src/zennomad/nomadconfig.cpp
new file mode 100644
index 000000000..d55b3da9a
--- /dev/null
+++ b/src/zennomad/nomadconfig.cpp
@@ -0,0 +1,91 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zennomad/nomadconfig.h>
+
+namespace zen::nomad {
+
+bool
+NomadConfig::Validate() const
+{
+ if (ServerUrl.empty())
+ {
+ return false;
+ }
+
+ if (BinDistribution == BinaryDistribution::PreDeployed && BinaryPath.empty())
+ {
+ return false;
+ }
+
+ if (BinDistribution == BinaryDistribution::Artifact && ArtifactSource.empty())
+ {
+ return false;
+ }
+
+ if (TaskDriver == Driver::Docker && DockerImage.empty())
+ {
+ return false;
+ }
+
+ return true;
+}
+
+const char*
+ToString(Driver D)
+{
+ switch (D)
+ {
+ case Driver::RawExec:
+ return "raw_exec";
+ case Driver::Docker:
+ return "docker";
+ }
+ return "raw_exec";
+}
+
+const char*
+ToString(BinaryDistribution Dist)
+{
+ switch (Dist)
+ {
+ case BinaryDistribution::PreDeployed:
+ return "predeployed";
+ case BinaryDistribution::Artifact:
+ return "artifact";
+ }
+ return "predeployed";
+}
+
+bool
+FromString(Driver& OutDriver, std::string_view Str)
+{
+ if (Str == "raw_exec")
+ {
+ OutDriver = Driver::RawExec;
+ return true;
+ }
+ if (Str == "docker")
+ {
+ OutDriver = Driver::Docker;
+ return true;
+ }
+ return false;
+}
+
+bool
+FromString(BinaryDistribution& OutDist, std::string_view Str)
+{
+ if (Str == "predeployed")
+ {
+ OutDist = BinaryDistribution::PreDeployed;
+ return true;
+ }
+ if (Str == "artifact")
+ {
+ OutDist = BinaryDistribution::Artifact;
+ return true;
+ }
+ return false;
+}
+
+} // namespace zen::nomad
diff --git a/src/zennomad/nomadprocess.cpp b/src/zennomad/nomadprocess.cpp
new file mode 100644
index 000000000..1ae968fb7
--- /dev/null
+++ b/src/zennomad/nomadprocess.cpp
@@ -0,0 +1,354 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zennomad/nomadclient.h>
+#include <zennomad/nomadprocess.h>
+
+#include <zenbase/zenbase.h>
+#include <zencore/fmtutils.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/memoryview.h>
+#include <zencore/process.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <fmt/format.h>
+
+namespace zen::nomad {
+
+//////////////////////////////////////////////////////////////////////////
+
+struct NomadProcess::Impl
+{
+ Impl(std::string_view BaseUri) : m_HttpClient(BaseUri) {}
+ ~Impl() = default;
+
+ void SpawnNomadAgent()
+ {
+ ZEN_TRACE_CPU("SpawnNomadAgent");
+
+ if (m_ProcessHandle.IsValid())
+ {
+ return;
+ }
+
+ CreateProcOptions Options;
+ Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup;
+
+ CreateProcResult Result = CreateProc("nomad" ZEN_EXE_SUFFIX_LITERAL, "nomad" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options);
+
+ if (Result)
+ {
+ m_ProcessHandle.Initialize(Result);
+
+ Stopwatch Timer;
+
+ // Poll to check when the agent is ready
+
+ do
+ {
+ Sleep(100);
+ HttpClient::Response Resp = m_HttpClient.Get("v1/status/leader");
+ if (Resp)
+ {
+ ZEN_INFO("Nomad agent started successfully (waited {})", NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+
+ return;
+ }
+ } while (Timer.GetElapsedTimeMs() < 30000);
+ }
+
+ // Report failure!
+
+ ZEN_WARN("Nomad agent failed to start within timeout period");
+ }
+
+ void StopNomadAgent()
+ {
+ if (!m_ProcessHandle.IsValid())
+ {
+ return;
+ }
+
+ // This waits for the process to exit and also resets the handle
+ m_ProcessHandle.Kill();
+ }
+
+private:
+ ProcessHandle m_ProcessHandle;
+ HttpClient m_HttpClient;
+};
+
// NomadProcess is a thin pimpl wrapper; construction binds the HTTP client
// to the local dev-agent endpoint (http://localhost:4646/).
NomadProcess::NomadProcess() : m_Impl(std::make_unique<Impl>("http://localhost:4646/"))
{
}

// Out-of-line destructor so unique_ptr<Impl> is destroyed where Impl is a
// complete type.
NomadProcess::~NomadProcess()
{
}

// Forwards to Impl: launches (if needed) a local 'nomad agent -dev'
// process and waits for it to become responsive.
void
NomadProcess::SpawnNomadAgent()
{
    m_Impl->SpawnNomadAgent();
}

// Forwards to Impl: kills the spawned agent process, if any.
void
NomadProcess::StopNomadAgent()
{
    m_Impl->StopNomadAgent();
}
+
+//////////////////////////////////////////////////////////////////////////
+
// Test-oriented Nomad API client; all requests go to the agent rooted at
// BaseUri.
NomadTestClient::NomadTestClient(std::string_view BaseUri) : m_HttpClient(BaseUri)
{
}

NomadTestClient::~NomadTestClient()
{
}
+
+NomadJobInfo
+NomadTestClient::SubmitJob(std::string_view JobId, std::string_view Command, const std::vector<std::string>& Args)
+{
+ ZEN_TRACE_CPU("SubmitNomadJob");
+
+ NomadJobInfo Result;
+
+ // Build the job JSON for a raw_exec batch job
+ json11::Json::object TaskConfig;
+ TaskConfig["command"] = std::string(Command);
+
+ json11::Json::array JsonArgs;
+ for (const auto& Arg : Args)
+ {
+ JsonArgs.push_back(Arg);
+ }
+ TaskConfig["args"] = JsonArgs;
+
+ json11::Json::object Resources;
+ Resources["CPU"] = 100;
+ Resources["MemoryMB"] = 64;
+
+ json11::Json::object Task;
+ Task["Name"] = "test-task";
+ Task["Driver"] = "raw_exec";
+ Task["Config"] = TaskConfig;
+ Task["Resources"] = Resources;
+
+ json11::Json::array Tasks;
+ Tasks.push_back(Task);
+
+ json11::Json::object Group;
+ Group["Name"] = "test-group";
+ Group["Count"] = 1;
+ Group["Tasks"] = Tasks;
+
+ json11::Json::array Groups;
+ Groups.push_back(Group);
+
+ json11::Json::array Datacenters;
+ Datacenters.push_back("dc1");
+
+ json11::Json::object Job;
+ Job["ID"] = std::string(JobId);
+ Job["Name"] = std::string(JobId);
+ Job["Type"] = "batch";
+ Job["Datacenters"] = Datacenters;
+ Job["TaskGroups"] = Groups;
+
+ json11::Json::object Root;
+ Root["Job"] = Job;
+
+ std::string Body = json11::Json(Root).dump();
+
+ IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{Body.data(), Body.size()}, ZenContentType::kJSON);
+
+ HttpClient::Response Response =
+ m_HttpClient.Put("v1/jobs", Payload, {{"Content-Type", "application/json"}, {"Accept", "application/json"}});
+
+ if (!Response || !Response.IsSuccess())
+ {
+ ZEN_WARN("NomadTestClient: SubmitJob failed for '{}'", JobId);
+ return Result;
+ }
+
+ std::string ResponseBody(Response.AsText());
+ std::string Err;
+ const json11::Json Json = json11::Json::parse(ResponseBody, Err);
+
+ if (!Err.empty())
+ {
+ ZEN_WARN("NomadTestClient: invalid JSON in SubmitJob response: {}", Err);
+ return Result;
+ }
+
+ Result.Id = std::string(JobId);
+ Result.Status = "pending";
+
+ ZEN_INFO("NomadTestClient: job '{}' submitted (eval_id={})", JobId, Json["EvalID"].string_value());
+
+ return Result;
+}
+
+NomadJobInfo
+NomadTestClient::GetJobStatus(std::string_view JobId)
+{
+ ZEN_TRACE_CPU("GetNomadJobStatus");
+
+ NomadJobInfo Result;
+
+ HttpClient::Response Response = m_HttpClient.Get(fmt::format("v1/job/{}", JobId));
+
+ if (Response.Error)
+ {
+ ZEN_WARN("NomadTestClient: GetJobStatus failed for '{}': {}", JobId, Response.Error->ErrorMessage);
+ return Result;
+ }
+
+ if (static_cast<int>(Response.StatusCode) == 404)
+ {
+ Result.Status = "dead";
+ return Result;
+ }
+
+ if (!Response.IsSuccess())
+ {
+ ZEN_WARN("NomadTestClient: GetJobStatus failed with HTTP/{}", static_cast<int>(Response.StatusCode));
+ return Result;
+ }
+
+ std::string Body(Response.AsText());
+ std::string Err;
+ const json11::Json Json = json11::Json::parse(Body, Err);
+
+ if (!Err.empty())
+ {
+ ZEN_WARN("NomadTestClient: invalid JSON in GetJobStatus response: {}", Err);
+ return Result;
+ }
+
+ Result.Id = Json["ID"].string_value();
+ Result.Status = Json["Status"].string_value();
+ if (const json11::Json Desc = Json["StatusDescription"]; Desc.is_string())
+ {
+ Result.StatusDescription = Desc.string_value();
+ }
+
+ return Result;
+}
+
+void
+NomadTestClient::StopJob(std::string_view JobId)
+{
+ ZEN_TRACE_CPU("StopNomadJob");
+
+ HttpClient::Response Response = m_HttpClient.Delete(fmt::format("v1/job/{}", JobId));
+
+ if (!Response || !Response.IsSuccess())
+ {
+ ZEN_WARN("NomadTestClient: StopJob failed for '{}'", JobId);
+ return;
+ }
+
+ ZEN_INFO("NomadTestClient: job '{}' stopped", JobId);
+}
+
+std::vector<NomadAllocInfo>
+NomadTestClient::GetAllocations(std::string_view JobId)
+{
+ ZEN_TRACE_CPU("GetNomadAllocations");
+
+ std::vector<NomadAllocInfo> Allocs;
+
+ HttpClient::Response Response = m_HttpClient.Get(fmt::format("v1/job/{}/allocations", JobId));
+
+ if (!Response || !Response.IsSuccess())
+ {
+ ZEN_WARN("NomadTestClient: GetAllocations failed for '{}'", JobId);
+ return Allocs;
+ }
+
+ std::string Body(Response.AsText());
+ std::string Err;
+ const json11::Json Json = json11::Json::parse(Body, Err);
+
+ if (!Err.empty() || !Json.is_array())
+ {
+ return Allocs;
+ }
+
+ for (const json11::Json& AllocVal : Json.array_items())
+ {
+ NomadAllocInfo Alloc;
+ Alloc.Id = AllocVal["ID"].string_value();
+ Alloc.ClientStatus = AllocVal["ClientStatus"].string_value();
+
+ if (const json11::Json TaskStates = AllocVal["TaskStates"]; TaskStates.is_object())
+ {
+ for (const auto& [TaskName, TaskState] : TaskStates.object_items())
+ {
+ if (TaskState["State"].is_string())
+ {
+ Alloc.TaskState = TaskState["State"].string_value();
+ }
+ }
+ }
+
+ Allocs.push_back(std::move(Alloc));
+ }
+
+ return Allocs;
+}
+
+std::vector<NomadJobInfo>
+NomadTestClient::ListJobs(std::string_view Prefix)
+{
+ ZEN_TRACE_CPU("ListNomadJobs");
+
+ std::vector<NomadJobInfo> Jobs;
+
+ std::string Url = "v1/jobs";
+ if (!Prefix.empty())
+ {
+ Url = fmt::format("v1/jobs?prefix={}", Prefix);
+ }
+
+ HttpClient::Response Response = m_HttpClient.Get(Url);
+
+ if (!Response || !Response.IsSuccess())
+ {
+ ZEN_WARN("NomadTestClient: ListJobs failed");
+ return Jobs;
+ }
+
+ std::string Body(Response.AsText());
+ std::string Err;
+ const json11::Json Json = json11::Json::parse(Body, Err);
+
+ if (!Err.empty() || !Json.is_array())
+ {
+ return Jobs;
+ }
+
+ for (const json11::Json& JobVal : Json.array_items())
+ {
+ NomadJobInfo Job;
+ Job.Id = JobVal["ID"].string_value();
+ Job.Status = JobVal["Status"].string_value();
+ if (const json11::Json Desc = JobVal["StatusDescription"]; Desc.is_string())
+ {
+ Job.StatusDescription = Desc.string_value();
+ }
+ Jobs.push_back(std::move(Job));
+ }
+
+ return Jobs;
+}
+
+} // namespace zen::nomad
diff --git a/src/zennomad/nomadprovisioner.cpp b/src/zennomad/nomadprovisioner.cpp
new file mode 100644
index 000000000..3fe9c0ac3
--- /dev/null
+++ b/src/zennomad/nomadprovisioner.cpp
@@ -0,0 +1,264 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zennomad/nomadclient.h>
+#include <zennomad/nomadprovisioner.h>
+
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/process.h>
+#include <zencore/scopeguard.h>
+#include <zencore/thread.h>
+#include <zencore/trace.h>
+
+#include <chrono>
+
+namespace zen::nomad {
+
// Builds the Nomad HTTP client from Config and, on success, launches the
// background management thread that submits/polls/cleans up jobs.
//
// If client initialization fails the provisioner is left inert (no thread
// started); the destructor handles that state via the joinable() check.
NomadProvisioner::NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint)
: m_Config(Config)
, m_OrchestratorEndpoint(OrchestratorEndpoint)
, m_ProcessId(static_cast<uint32_t>(zen::GetCurrentProcessId()))
, m_Log(zen::logging::Get("nomad.provisioner"))
{
    ZEN_DEBUG("initializing provisioner (server: {}, driver: {}, max_cores: {}, cores_per_job: {}, max_jobs: {})",
              m_Config.ServerUrl,
              ToString(m_Config.TaskDriver),
              m_Config.MaxCores,
              m_Config.CoresPerJob,
              m_Config.MaxJobs);

    m_Client = std::make_unique<NomadClient>(m_Config);
    if (!m_Client->Initialize())
    {
        ZEN_ERROR("failed to initialize Nomad HTTP client");
        return;
    }

    ZEN_DEBUG("Nomad HTTP client initialized, starting management thread");

    m_Thread = std::thread([this] { ManagementThread(); });
}
+
+NomadProvisioner::~NomadProvisioner()
+{
+ ZEN_DEBUG("provisioner shutting down");
+
+ m_ShouldExit.store(true);
+ m_WakeCV.notify_all();
+
+ if (m_Thread.joinable())
+ {
+ m_Thread.join();
+ }
+
+ StopAllJobs();
+
+ ZEN_DEBUG("provisioner shutdown complete");
+}
+
// Sets the number of cores the provisioner should try to keep available,
// clamped to the configured MaxCores ceiling.
//
// NOTE(review): the notify below wakes the management thread's wait_for,
// but its predicate only checks m_ShouldExit, so the wait resumes for the
// remaining timeout instead of acting immediately — a target change is
// therefore picked up on the next (<=5s) cycle. Confirm that latency is
// acceptable.
void
NomadProvisioner::SetTargetCoreCount(uint32_t Count)
{
    const uint32_t Clamped = std::min(Count, static_cast<uint32_t>(m_Config.MaxCores));
    const uint32_t Previous = m_TargetCoreCount.exchange(Clamped);

    if (Clamped != Previous)
    {
        ZEN_DEBUG("target core count changed: {} -> {}", Previous, Clamped);
    }

    m_WakeCV.notify_all();
}
+
+NomadProvisioningStats
+NomadProvisioner::GetStats() const
+{
+ NomadProvisioningStats Stats;
+ Stats.TargetCoreCount = m_TargetCoreCount.load();
+ Stats.EstimatedCoreCount = m_EstimatedCoreCount.load();
+ Stats.RunningJobCount = m_RunningJobCount.load();
+
+ {
+ std::lock_guard<std::mutex> Lock(m_JobsLock);
+ Stats.ActiveJobCount = static_cast<uint32_t>(m_Jobs.size());
+ }
+
+ return Stats;
+}
+
+std::string
+NomadProvisioner::GenerateJobId()
+{
+ const uint32_t Index = m_JobIndex.fetch_add(1);
+
+ ExtendableStringBuilder<128> Builder;
+ Builder << m_Config.JobPrefix << "-" << m_ProcessId << "-" << Index;
+ return std::string(Builder.ToView());
+}
+
// Background loop driving the provisioner: each cycle submits jobs until
// the core target is met, refreshes tracked job statuses, evicts dead
// jobs, then sleeps up to 5 seconds or until woken for shutdown.
void
NomadProvisioner::ManagementThread()
{
    ZEN_TRACE_CPU("Nomad_Mgmt");
    zen::SetCurrentThreadName("nomad_mgmt");

    ZEN_INFO("Nomad management thread started");

    while (!m_ShouldExit.load())
    {
        // The lambda takes m_JobsLock just to read the job count for this
        // line; presumably ZEN_DEBUG short-circuits when the level is
        // disabled — confirm it does not evaluate (and lock) every cycle
        // in release builds.
        ZEN_DEBUG("management cycle: target={} estimated={} running={} active={}",
                  m_TargetCoreCount.load(),
                  m_EstimatedCoreCount.load(),
                  m_RunningJobCount.load(),
                  [this] {
                      std::lock_guard<std::mutex> Lock(m_JobsLock);
                      return m_Jobs.size();
                  }());

        SubmitNewJobs();
        PollExistingJobs();
        CleanupDeadJobs();

        // Wait up to 5 seconds or until woken; the predicate only ends the
        // wait early for shutdown.
        std::unique_lock<std::mutex> Lock(m_WakeMutex);
        m_WakeCV.wait_for(Lock, std::chrono::seconds(5), [this] { return m_ShouldExit.load(); });
    }

    ZEN_INFO("Nomad management thread exiting");
}
+
// Submits new Nomad jobs until the estimated core count reaches the
// current target, the MaxJobs ceiling is hit, a submission fails, or
// shutdown is requested.
void
NomadProvisioner::SubmitNewJobs()
{
    ZEN_TRACE_CPU("NomadProvisioner::SubmitNewJobs");

    const uint32_t CoresPerJob = static_cast<uint32_t>(m_Config.CoresPerJob);

    while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load())
    {
        {
            std::lock_guard<std::mutex> Lock(m_JobsLock);
            if (static_cast<int>(m_Jobs.size()) >= m_Config.MaxJobs)
            {
                // NOTE(review): this logs at INFO every management cycle
                // while saturated — consider demoting to DEBUG if noisy.
                ZEN_INFO("Nomad max jobs limit reached ({})", m_Config.MaxJobs);
                break;
            }
        }

        if (m_ShouldExit.load())
        {
            break;
        }

        const std::string JobId = GenerateJobId();

        ZEN_DEBUG("submitting job '{}' (estimated: {}, target: {})", JobId, m_EstimatedCoreCount.load(), m_TargetCoreCount.load());

        const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint);

        NomadJobInfo JobInfo;
        JobInfo.Id = JobId;

        // A failed submission aborts this cycle; retried on the next one.
        if (!m_Client->SubmitJob(JobJson, JobInfo))
        {
            ZEN_WARN("failed to submit Nomad job '{}'", JobId);
            break;
        }

        TrackedJob Tracked;
        Tracked.JobId = JobId;
        Tracked.Status = "pending";
        Tracked.Cores = static_cast<int>(CoresPerJob);

        {
            std::lock_guard<std::mutex> Lock(m_JobsLock);
            m_Jobs.push_back(std::move(Tracked));
        }

        // Cores are counted optimistically at submission time;
        // CleanupDeadJobs subtracts them again if the job dies.
        m_EstimatedCoreCount.fetch_add(CoresPerJob);

        ZEN_INFO("Nomad job '{}' submitted (estimated cores: {})", JobId, m_EstimatedCoreCount.load());
    }
}
+
+void
+NomadProvisioner::PollExistingJobs()
+{
+ ZEN_TRACE_CPU("NomadProvisioner::PollExistingJobs");
+
+ std::lock_guard<std::mutex> Lock(m_JobsLock);
+
+ for (auto& Job : m_Jobs)
+ {
+ if (m_ShouldExit.load())
+ {
+ break;
+ }
+
+ NomadJobInfo Info;
+ if (!m_Client->GetJobStatus(Job.JobId, Info))
+ {
+ ZEN_DEBUG("failed to poll status for job '{}'", Job.JobId);
+ continue;
+ }
+
+ const std::string PrevStatus = Job.Status;
+ Job.Status = Info.Status;
+
+ if (PrevStatus != Job.Status)
+ {
+ ZEN_INFO("Nomad job '{}' status changed: {} -> {}", Job.JobId, PrevStatus, Job.Status);
+
+ if (Job.Status == "running" && PrevStatus != "running")
+ {
+ m_RunningJobCount.fetch_add(1);
+ }
+ else if (Job.Status != "running" && PrevStatus == "running")
+ {
+ m_RunningJobCount.fetch_sub(1);
+ }
+ }
+ }
+}
+
+void
+NomadProvisioner::CleanupDeadJobs()
+{
+ ZEN_TRACE_CPU("NomadProvisioner::CleanupDeadJobs");
+
+ std::lock_guard<std::mutex> Lock(m_JobsLock);
+
+ for (auto It = m_Jobs.begin(); It != m_Jobs.end();)
+ {
+ if (It->Status == "dead")
+ {
+ ZEN_INFO("Nomad job '{}' is dead, removing from tracked jobs", It->JobId);
+ m_EstimatedCoreCount.fetch_sub(static_cast<uint32_t>(It->Cores));
+ It = m_Jobs.erase(It);
+ }
+ else
+ {
+ ++It;
+ }
+ }
+}
+
+void
+NomadProvisioner::StopAllJobs()
+{
+ ZEN_TRACE_CPU("NomadProvisioner::StopAllJobs");
+
+ std::lock_guard<std::mutex> Lock(m_JobsLock);
+
+ for (const auto& Job : m_Jobs)
+ {
+ ZEN_INFO("stopping Nomad job '{}' during shutdown", Job.JobId);
+ m_Client->StopJob(Job.JobId);
+ }
+
+ m_Jobs.clear();
+ m_EstimatedCoreCount.store(0);
+ m_RunningJobCount.store(0);
+}
+
+} // namespace zen::nomad
diff --git a/src/zennomad/xmake.lua b/src/zennomad/xmake.lua
new file mode 100644
index 000000000..ef1a8b201
--- /dev/null
+++ b/src/zennomad/xmake.lua
@@ -0,0 +1,10 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
-- Static library wrapping the Nomad cluster integration (config, client,
-- process management, provisioner).
target('zennomad')
    set_kind("static")
    set_group("libs")
    -- All headers/sources beneath this directory.
    add_headerfiles("**.h")
    add_files("**.cpp")
    -- Public include dir so dependants can #include <zennomad/...>.
    add_includedirs("include", {public=true})
    add_deps("zencore", "zenhttp", "zenutil")
    add_packages("json11")
diff --git a/src/zenremotestore-test/zenremotestore-test.cpp b/src/zenremotestore-test/zenremotestore-test.cpp
index 5db185041..dc47c5aed 100644
--- a/src/zenremotestore-test/zenremotestore-test.cpp
+++ b/src/zenremotestore-test/zenremotestore-test.cpp
@@ -1,46 +1,15 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zencore/filesystem.h>
-#include <zencore/logging.h>
-#include <zencore/trace.h>
-#include <zenremotestore/projectstore/remoteprojectstore.h>
+#include <zencore/testing.h>
#include <zenremotestore/zenremotestore.h>
#include <zencore/memory/newdelete.h>
-#if ZEN_WITH_TESTS
-# define ZEN_TEST_WITH_RUNNER 1
-# include <zencore/testing.h>
-# include <zencore/process.h>
-#endif
-
int
main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
{
-#if ZEN_PLATFORM_WINDOWS
- setlocale(LC_ALL, "en_us.UTF8");
-#endif // ZEN_PLATFORM_WINDOWS
-
#if ZEN_WITH_TESTS
- zen::zenremotestore_forcelinktests();
-
-# if ZEN_PLATFORM_LINUX
- zen::IgnoreChildSignals();
-# endif
-
-# if ZEN_WITH_TRACE
- zen::TraceInit("zenstore-test");
- zen::TraceOptions TraceCommandlineOptions;
- if (GetTraceOptionsFromCommandline(TraceCommandlineOptions))
- {
- TraceConfigure(TraceCommandlineOptions);
- }
-# endif // ZEN_WITH_TRACE
-
- zen::logging::InitializeLogging();
- zen::MaximizeOpenFileCount();
-
- return ZEN_RUN_TESTS(argc, argv);
+ return zen::testing::RunTestMain(argc, argv, "zenremotestore-test", zen::zenremotestore_forcelinktests);
#else
return 0;
#endif
diff --git a/src/zenremotestore/builds/buildmanifest.cpp b/src/zenremotestore/builds/buildmanifest.cpp
index 051436e96..738e4b33b 100644
--- a/src/zenremotestore/builds/buildmanifest.cpp
+++ b/src/zenremotestore/builds/buildmanifest.cpp
@@ -97,6 +97,8 @@ ParseBuildManifest(const std::filesystem::path& ManifestPath)
}
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("remotestore.buildmanifest");
+
TEST_CASE("buildmanifest.unstructured")
{
ScopedTemporaryDirectory Root;
@@ -163,6 +165,8 @@ TEST_CASE("buildmanifest.structured")
CHECK_EQ(Manifest.Parts[1].Files[0].generic_string(), "baz.pdb");
}
+TEST_SUITE_END();
+
void
buildmanifest_forcelink()
{
diff --git a/src/zenremotestore/builds/buildsavedstate.cpp b/src/zenremotestore/builds/buildsavedstate.cpp
index 1d1f4605f..0685bf679 100644
--- a/src/zenremotestore/builds/buildsavedstate.cpp
+++ b/src/zenremotestore/builds/buildsavedstate.cpp
@@ -588,6 +588,8 @@ namespace buildsavestate_test {
}
} // namespace buildsavestate_test
+TEST_SUITE_BEGIN("remotestore.buildsavedstate");
+
TEST_CASE("buildsavestate.BuildsSelection")
{
using namespace buildsavestate_test;
@@ -696,6 +698,8 @@ TEST_CASE("buildsavestate.DownloadedPaths")
}
}
+TEST_SUITE_END();
+
#endif // ZEN_WITH_TESTS
} // namespace zen
diff --git a/src/zenremotestore/builds/buildstoragecache.cpp b/src/zenremotestore/builds/buildstoragecache.cpp
index 07fcd62ba..00765903d 100644
--- a/src/zenremotestore/builds/buildstoragecache.cpp
+++ b/src/zenremotestore/builds/buildstoragecache.cpp
@@ -151,7 +151,7 @@ public:
auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); });
HttpClient::Response CacheResponse =
- m_HttpClient.Upload(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash.ToHexString()),
+ m_HttpClient.Upload(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash),
Payload,
ContentType);
@@ -180,7 +180,7 @@ public:
}
CreateDirectories(m_TempFolderPath);
HttpClient::Response CacheResponse =
- m_HttpClient.Download(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash.ToHexString()),
+ m_HttpClient.Download(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash),
m_TempFolderPath,
Headers);
AddStatistic(CacheResponse);
@@ -191,6 +191,74 @@ public:
return {};
}
+ virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId,
+ const IoHash& RawHash,
+ std::span<const std::pair<uint64_t, uint64_t>> Ranges) override
+ {
+ ZEN_TRACE_CPU("ZenBuildStorageCache::GetBuildBlobRanges");
+
+ Stopwatch ExecutionTimer;
+ auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); });
+
+ CbObjectWriter Writer;
+ Writer.BeginArray("ranges"sv);
+ {
+ for (const std::pair<uint64_t, uint64_t>& Range : Ranges)
+ {
+ Writer.BeginObject();
+ {
+ Writer.AddInteger("offset"sv, Range.first);
+ Writer.AddInteger("length"sv, Range.second);
+ }
+ Writer.EndObject();
+ }
+ }
+ Writer.EndArray(); // ranges
+
+ CreateDirectories(m_TempFolderPath);
+ HttpClient::Response CacheResponse =
+ m_HttpClient.Post(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash),
+ Writer.Save(),
+ HttpClient::Accept(ZenContentType::kCbPackage));
+ AddStatistic(CacheResponse);
+ if (CacheResponse.IsSuccess())
+ {
+ CbPackage ResponsePackage = ParsePackageMessage(CacheResponse.ResponsePayload);
+ CbObjectView ResponseObject = ResponsePackage.GetObject();
+
+ CbArrayView RangeArray = ResponseObject["ranges"sv].AsArrayView();
+
+ std::vector<std::pair<uint64_t, uint64_t>> ReceivedRanges;
+ ReceivedRanges.reserve(RangeArray.Num());
+
+ uint64_t OffsetInPayloadRanges = 0;
+
+ for (CbFieldView View : RangeArray)
+ {
+ CbObjectView RangeView = View.AsObjectView();
+ uint64_t Offset = RangeView["offset"sv].AsUInt64();
+ uint64_t Length = RangeView["length"sv].AsUInt64();
+
+ const std::pair<uint64_t, uint64_t>& Range = Ranges[ReceivedRanges.size()];
+
+ if (Offset != Range.first || Length != Range.second)
+ {
+ return {};
+ }
+ ReceivedRanges.push_back(std::make_pair(OffsetInPayloadRanges, Length));
+ OffsetInPayloadRanges += Length;
+ }
+
+ const CbAttachment* DataAttachment = ResponsePackage.FindAttachment(RawHash);
+ if (DataAttachment)
+ {
+ SharedBuffer PayloadRanges = DataAttachment->AsBinary();
+ return BuildBlobRanges{.PayloadBuffer = PayloadRanges.AsIoBuffer(), .Ranges = std::move(ReceivedRanges)};
+ }
+ }
+ return {};
+ }
+
virtual void PutBlobMetadatas(const Oid& BuildId, std::span<const IoHash> BlobHashes, std::span<const CbObject> MetaDatas) override
{
ZEN_ASSERT(!IsFlushed);
@@ -460,6 +528,192 @@ CreateZenBuildStorageCache(HttpClient& HttpClient,
return std::make_unique<ZenBuildStorageCache>(HttpClient, Stats, Namespace, Bucket, TempFolderPath, BackgroundWorkerPool);
}
+#if ZEN_WITH_TESTS
+
// Test double for BuildStorageCache keeping blobs in an in-memory hash
// map, with configurable range-request support and optional simulated
// transport latency. Thread-safe for concurrent blob access via m_Mutex.
class InMemoryBuildStorageCache : public BuildStorageCache
{
public:
    // MaxRangeSupported == 0 : no range requests are accepted, always return full blob
    // MaxRangeSupported == 1 : single range is supported, multi range returns full blob
    // MaxRangeSupported > 1 : multirange is supported up to MaxRangeSupported, more ranges returns empty blob (bad request)
    //
    // LatencySec adds a fixed sleep per operation; DelayPerKBSec adds a
    // sleep proportional to the bytes moved.
    explicit InMemoryBuildStorageCache(uint64_t MaxRangeSupported,
                                       BuildStorageCache::Statistics& Stats,
                                       double LatencySec = 0.0,
                                       double DelayPerKBSec = 0.0)
    : m_MaxRangeSupported(MaxRangeSupported)
    , m_Stats(Stats)
    , m_LatencySec(LatencySec)
    , m_DelayPerKBSec(DelayPerKBSec)
    {
    }

    // Stores (or replaces) a blob keyed by its raw hash; BuildId is ignored.
    void PutBuildBlob(const Oid&, const IoHash& RawHash, ZenContentType, const CompositeBuffer& Payload) override
    {
        IoBuffer Buf = Payload.Flatten().AsIoBuffer();
        Buf.MakeOwned();
        const uint64_t SentBytes = Buf.Size();
        uint64_t ReceivedBytes = 0;
        // Simulate the upstream latency on entry; the guard simulates the
        // downstream latency (scaled by bytes "received") on exit.
        SimulateLatency(SentBytes, 0);
        auto _ = MakeGuard([&]() { SimulateLatency(0, ReceivedBytes); });
        Stopwatch ExecutionTimer;
        auto __ = MakeGuard([&]() { AddStatistic(ExecutionTimer.GetElapsedTimeUs(), ReceivedBytes, SentBytes); });
        {
            std::lock_guard Lock(m_Mutex);
            m_Entries[RawHash] = std::move(Buf);
        }
        m_Stats.PutBlobCount.fetch_add(1);
        m_Stats.PutBlobByteCount.fetch_add(SentBytes);
    }

    // Returns the stored blob, or the requested sub-range when range
    // requests are enabled (m_MaxRangeSupported != 0). Empty buffer when
    // the hash is unknown.
    IoBuffer GetBuildBlob(const Oid&, const IoHash& RawHash, uint64_t RangeOffset = 0, uint64_t RangeBytes = (uint64_t)-1) override
    {
        uint64_t SentBytes = 0;
        uint64_t ReceivedBytes = 0;
        SimulateLatency(SentBytes, 0);
        auto _ = MakeGuard([&]() { SimulateLatency(0, ReceivedBytes); });
        Stopwatch ExecutionTimer;
        auto __ = MakeGuard([&]() { AddStatistic(ExecutionTimer.GetElapsedTimeUs(), ReceivedBytes, SentBytes); });
        IoBuffer FullPayload;
        {
            std::lock_guard Lock(m_Mutex);
            auto It = m_Entries.find(RawHash);
            if (It == m_Entries.end())
            {
                return {};
            }
            FullPayload = It->second;
        }

        if (RangeOffset != 0 || RangeBytes != (uint64_t)-1)
        {
            if (m_MaxRangeSupported == 0)
            {
                // Ranges unsupported: hand back the whole blob instead.
                ReceivedBytes = FullPayload.Size();
                return FullPayload;
            }
            else
            {
                ReceivedBytes = (RangeBytes == (uint64_t)-1) ? FullPayload.Size() - RangeOffset : RangeBytes;
                // NOTE(review): RangeBytes can still be (uint64_t)-1 here
                // (offset-only request); presumably the IoBuffer view ctor
                // clamps to the end of the buffer — confirm.
                return IoBuffer(FullPayload, RangeOffset, RangeBytes);
            }
        }
        else
        {
            ReceivedBytes = FullPayload.Size();
            return FullPayload;
        }
    }

    // Serves multiple ranges of a blob in one call, mimicking the server's
    // range-count negotiation:
    //  - more ranges than a (>1) limit -> empty result ("bad request")
    //  - ranges unsupported/over limit otherwise -> full blob with an
    //    empty Ranges list ("use the full buffer for every range")
    //  - within the limit -> one buffer covering first..last range, with
    //    each returned offset relative to that buffer.
    // NOTE(review): the contiguous-buffer math assumes Ranges is sorted
    // ascending by offset — confirm callers guarantee this.
    BuildBlobRanges GetBuildBlobRanges(const Oid&, const IoHash& RawHash, std::span<const std::pair<uint64_t, uint64_t>> Ranges) override
    {
        ZEN_ASSERT(!Ranges.empty());
        uint64_t SentBytes = 0;
        uint64_t ReceivedBytes = 0;
        SimulateLatency(SentBytes, 0);
        auto _ = MakeGuard([&]() { SimulateLatency(0, ReceivedBytes); });
        Stopwatch ExecutionTimer;
        auto __ = MakeGuard([&]() { AddStatistic(ExecutionTimer.GetElapsedTimeUs(), ReceivedBytes, SentBytes); });
        if (m_MaxRangeSupported > 1 && Ranges.size() > m_MaxRangeSupported)
        {
            return {};
        }
        IoBuffer FullPayload;
        {
            std::lock_guard Lock(m_Mutex);
            auto It = m_Entries.find(RawHash);
            if (It == m_Entries.end())
            {
                return {};
            }
            FullPayload = It->second;
        }

        if (Ranges.size() > m_MaxRangeSupported)
        {
            // An empty Ranges signals to the caller: "full buffer given, use it for all requested ranges".
            ReceivedBytes = FullPayload.Size();
            return {.PayloadBuffer = FullPayload};
        }
        else
        {
            uint64_t PayloadStart = Ranges.front().first;
            uint64_t PayloadSize = Ranges.back().first + Ranges.back().second - PayloadStart;
            IoBuffer RangeBuffer = IoBuffer(FullPayload, PayloadStart, PayloadSize);
            std::vector<std::pair<uint64_t, uint64_t>> PayloadRanges;
            PayloadRanges.reserve(Ranges.size());
            for (const std::pair<uint64_t, uint64_t>& Range : Ranges)
            {
                // Rebase each requested offset onto the trimmed buffer.
                PayloadRanges.push_back(std::make_pair(Range.first - PayloadStart, Range.second));
            }
            ReceivedBytes = PayloadSize;
            return {.PayloadBuffer = RangeBuffer, .Ranges = std::move(PayloadRanges)};
        }
    }

    // Metadata is not modelled by this test double.
    void PutBlobMetadatas(const Oid&, std::span<const IoHash>, std::span<const CbObject>) override {}

    // Returns one empty metadata object per requested hash.
    std::vector<CbObject> GetBlobMetadatas(const Oid&, std::span<const IoHash> Hashes) override
    {
        return std::vector<CbObject>(Hashes.size());
    }

    // Reports, per hash, whether a non-empty body is stored.
    std::vector<BlobExistsResult> BlobsExists(const Oid&, std::span<const IoHash> Hashes) override
    {
        std::lock_guard Lock(m_Mutex);
        std::vector<BlobExistsResult> Result;
        Result.reserve(Hashes.size());
        for (const IoHash& Hash : Hashes)
        {
            auto It = m_Entries.find(Hash);
            Result.push_back({.HasBody = (It != m_Entries.end() && It->second)});
        }
        return Result;
    }

    // Nothing is buffered, so flushing is a no-op.
    void Flush(int32_t, std::function<bool(intptr_t)>&&) override {}

private:
    // Folds one simulated request into the shared statistics block.
    void AddStatistic(uint64_t ElapsedTimeUs, uint64_t ReceivedBytes, uint64_t SentBytes)
    {
        m_Stats.TotalBytesWritten += SentBytes;
        m_Stats.TotalBytesRead += ReceivedBytes;
        m_Stats.TotalExecutionTimeUs += ElapsedTimeUs;
        m_Stats.TotalRequestCount++;
        SetAtomicMax(m_Stats.PeakSentBytes, SentBytes);
        SetAtomicMax(m_Stats.PeakReceivedBytes, ReceivedBytes);
        if (ElapsedTimeUs > 0)
        {
            SetAtomicMax(m_Stats.PeakBytesPerSec, (ReceivedBytes + SentBytes) * 1000000 / ElapsedTimeUs);
        }
    }

    // Sleeps for the configured base latency plus a per-KB transfer delay.
    void SimulateLatency(uint64_t SendBytes, uint64_t ReceiveBytes)
    {
        double SleepSec = m_LatencySec;
        if (m_DelayPerKBSec > 0.0)
        {
            SleepSec += m_DelayPerKBSec * (double(SendBytes + ReceiveBytes) / 1024u);
        }
        if (SleepSec > 0)
        {
            Sleep(int(SleepSec * 1000));
        }
    }

    uint64_t m_MaxRangeSupported = 0;              // see constructor comment for semantics
    BuildStorageCache::Statistics& m_Stats;        // shared with the code under test
    const double m_LatencySec = 0.0;               // fixed sleep per operation
    const double m_DelayPerKBSec = 0.0;            // additional sleep per KB moved
    std::mutex m_Mutex;                            // guards m_Entries
    std::unordered_map<IoHash, IoBuffer, IoHash::Hasher> m_Entries;
};
+
// Factory for the in-memory test cache; see InMemoryBuildStorageCache for
// the semantics of MaxRangeSupported and the latency simulation knobs.
std::unique_ptr<BuildStorageCache>
CreateInMemoryBuildStorageCache(uint64_t MaxRangeSupported, BuildStorageCache::Statistics& Stats, double LatencySec, double DelayPerKBSec)
{
    return std::make_unique<InMemoryBuildStorageCache>(MaxRangeSupported, Stats, LatencySec, DelayPerKBSec);
}
+#endif // ZEN_WITH_TESTS
+
ZenCacheEndpointTestResult
TestZenCacheEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpVerbose)
{
@@ -474,9 +728,28 @@ TestZenCacheEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const boo
HttpClient::Response TestResponse = TestHttpClient.Get("/status/builds");
if (TestResponse.IsSuccess())
{
- return {.Success = true};
+ uint64_t MaxRangeCountPerRequest = 1;
+ CbObject StatusResponse = TestResponse.AsObject();
+ if (StatusResponse["ok"].AsBool())
+ {
+ MaxRangeCountPerRequest = StatusResponse["capabilities"].AsObjectView()["maxrangecountperrequest"].AsUInt64(1);
+
+ LatencyTestResult LatencyResult = MeasureLatency(TestHttpClient, "/health");
+
+ if (!LatencyResult.Success)
+ {
+ return {.Success = false, .FailureReason = LatencyResult.FailureReason};
+ }
+
+ return {.Success = true, .LatencySeconds = LatencyResult.LatencySeconds, .MaxRangeCountPerRequest = MaxRangeCountPerRequest};
+ }
+ else
+ {
+ return {.Success = false,
+ .FailureReason = fmt::format("ZenCache endpoint {}/status/builds did not respond with \"ok\"", BaseUrl)};
+ }
}
return {.Success = false, .FailureReason = TestResponse.ErrorMessage("")};
-};
+}
} // namespace zen
diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp
index 2319ad66d..f4b167b73 100644
--- a/src/zenremotestore/builds/buildstorageoperations.cpp
+++ b/src/zenremotestore/builds/buildstorageoperations.cpp
@@ -38,6 +38,7 @@ ZEN_THIRD_PARTY_INCLUDES_END
#if ZEN_WITH_TESTS
# include <zencore/testing.h>
# include <zencore/testutils.h>
+# include <zenhttp/httpclientauth.h>
# include <zenremotestore/builds/filebuildstorage.h>
#endif // ZEN_WITH_TESTS
@@ -484,24 +485,6 @@ private:
uint64_t FilteredPerSecond = 0;
};
-EPartialBlockRequestMode
-PartialBlockRequestModeFromString(const std::string_view ModeString)
-{
- switch (HashStringAsLowerDjb2(ModeString))
- {
- case HashStringDjb2("false"):
- return EPartialBlockRequestMode::Off;
- case HashStringDjb2("zencacheonly"):
- return EPartialBlockRequestMode::ZenCacheOnly;
- case HashStringDjb2("mixed"):
- return EPartialBlockRequestMode::Mixed;
- case HashStringDjb2("true"):
- return EPartialBlockRequestMode::All;
- default:
- return EPartialBlockRequestMode::Invalid;
- }
-}
-
std::filesystem::path
ZenStateFilePath(const std::filesystem::path& ZenFolderPath)
{
@@ -579,13 +562,6 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
CreateDirectories(m_TempDownloadFolderPath);
CreateDirectories(m_TempBlockFolderPath);
- Stopwatch IndexTimer;
-
- if (!m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Indexed local and remote content in {}", NiceTimeSpanMs(IndexTimer.GetElapsedTimeMs()));
- }
-
Stopwatch CacheMappingTimer;
std::vector<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters(m_RemoteContent.ChunkedContent.SequenceRawHashes.size());
@@ -906,343 +882,240 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
CheckRequiredDiskSpace(RemotePathToRemoteIndex);
+ BlobsExistsResult ExistsResult;
{
- ZEN_TRACE_CPU("WriteChunks");
-
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::WriteChunks, (uint32_t)TaskSteps::StepCount);
-
- Stopwatch WriteTimer;
-
- FilteredRate FilteredDownloadedBytesPerSecond;
- FilteredRate FilteredWrittenBytesPerSecond;
-
- std::unique_ptr<OperationLogOutput::ProgressBar> WriteProgressBarPtr(
- m_LogOutput.CreateProgressBar(m_Options.PrimeCacheOnly ? "Downloading" : "Writing"));
- OperationLogOutput::ProgressBar& WriteProgressBar(*WriteProgressBarPtr);
- ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+ ChunkBlockAnalyser BlockAnalyser(
+ m_LogOutput,
+ m_BlockDescriptions,
+ ChunkBlockAnalyser::Options{.IsQuiet = m_Options.IsQuiet,
+ .IsVerbose = m_Options.IsVerbose,
+ .HostLatencySec = m_Storage.BuildStorageHost.LatencySec,
+ .HostHighSpeedLatencySec = m_Storage.CacheHost.LatencySec,
+ .HostMaxRangeCountPerRequest = m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest,
+ .HostHighSpeedMaxRangeCountPerRequest = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest});
- struct LooseChunkHashWorkData
- {
- std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs;
- uint32_t RemoteChunkIndex = (uint32_t)-1;
- };
+ std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = BlockAnalyser.GetNeeded(
+ m_RemoteLookup.ChunkHashToChunkIndex,
+ [&](uint32_t RemoteChunkIndex) -> bool { return RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex]; });
- std::vector<LooseChunkHashWorkData> LooseChunkHashWorks;
- TotalPartWriteCount += CopyChunkDatas.size();
- TotalPartWriteCount += ScavengedSequenceCopyOperations.size();
+ std::vector<uint32_t> FetchBlockIndexes;
+ std::vector<uint32_t> CachedChunkBlockIndexes;
- for (const IoHash ChunkHash : m_LooseChunkHashes)
{
- auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash);
- ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end());
- const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second;
- if (RemoteChunkIndexNeedsCopyFromLocalFileFlags[RemoteChunkIndex])
+ ZEN_TRACE_CPU("BlockCacheFileExists");
+ for (const ChunkBlockAnalyser::NeededBlock& NeededBlock : NeededBlocks)
{
- if (m_Options.IsVerbose)
+ if (m_Options.PrimeCacheOnly)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Skipping chunk {} due to cache reuse", ChunkHash);
- }
- continue;
- }
- bool NeedsCopy = true;
- if (RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex].compare_exchange_strong(NeedsCopy, false))
- {
- std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs =
- GetRemainingChunkTargets(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex);
-
- if (ChunkTargetPtrs.empty())
- {
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Skipping chunk {} due to cache reuse", ChunkHash);
- }
+ FetchBlockIndexes.push_back(NeededBlock.BlockIndex);
}
else
{
- TotalRequestCount++;
- TotalPartWriteCount++;
- LooseChunkHashWorks.push_back(
- LooseChunkHashWorkData{.ChunkTargetPtrs = ChunkTargetPtrs, .RemoteChunkIndex = RemoteChunkIndex});
- }
- }
- }
-
- uint32_t BlockCount = gsl::narrow<uint32_t>(m_BlockDescriptions.size());
-
- std::vector<bool> ChunkIsPickedUpByBlock(m_RemoteContent.ChunkedContent.ChunkHashes.size(), false);
- auto GetNeededChunkBlockIndexes = [this, &RemoteChunkIndexNeedsCopyFromSourceFlags, &ChunkIsPickedUpByBlock](
- const ChunkBlockDescription& BlockDescription) {
- ZEN_TRACE_CPU("GetNeededChunkBlockIndexes");
- std::vector<uint32_t> NeededBlockChunkIndexes;
- for (uint32_t ChunkBlockIndex = 0; ChunkBlockIndex < BlockDescription.ChunkRawHashes.size(); ChunkBlockIndex++)
- {
- const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex];
- if (auto It = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); It != m_RemoteLookup.ChunkHashToChunkIndex.end())
- {
- const uint32_t RemoteChunkIndex = It->second;
- if (!ChunkIsPickedUpByBlock[RemoteChunkIndex])
+ const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[NeededBlock.BlockIndex];
+ bool UsingCachedBlock = false;
+ if (auto It = CachedBlocksFound.find(BlockDescription.BlockHash); It != CachedBlocksFound.end())
{
- if (RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex])
+ TotalPartWriteCount++;
+
+ std::filesystem::path BlockPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString();
+ if (IsFile(BlockPath))
{
- ChunkIsPickedUpByBlock[RemoteChunkIndex] = true;
- NeededBlockChunkIndexes.push_back(ChunkBlockIndex);
+ CachedChunkBlockIndexes.push_back(NeededBlock.BlockIndex);
+ UsingCachedBlock = true;
}
}
- }
- else
- {
- ZEN_DEBUG("Chunk {} not found in block {}", ChunkHash, BlockDescription.BlockHash);
+ if (!UsingCachedBlock)
+ {
+ FetchBlockIndexes.push_back(NeededBlock.BlockIndex);
+ }
}
}
- return NeededBlockChunkIndexes;
- };
+ }
- std::vector<uint32_t> CachedChunkBlockIndexes;
- std::vector<uint32_t> FetchBlockIndexes;
- std::vector<std::vector<uint32_t>> AllBlockChunkIndexNeeded;
+ std::vector<uint32_t> NeededLooseChunkIndexes;
- for (uint32_t BlockIndex = 0; BlockIndex < BlockCount; BlockIndex++)
{
- const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
-
- std::vector<uint32_t> BlockChunkIndexNeeded = GetNeededChunkBlockIndexes(BlockDescription);
- if (!BlockChunkIndexNeeded.empty())
+ NeededLooseChunkIndexes.reserve(m_LooseChunkHashes.size());
+ for (uint32_t LooseChunkIndex = 0; LooseChunkIndex < m_LooseChunkHashes.size(); LooseChunkIndex++)
{
- if (m_Options.PrimeCacheOnly)
+ const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex];
+ auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash);
+ ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end());
+ const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second;
+
+ if (RemoteChunkIndexNeedsCopyFromLocalFileFlags[RemoteChunkIndex])
{
- FetchBlockIndexes.push_back(BlockIndex);
+ if (m_Options.IsVerbose)
+ {
+ ZEN_OPERATION_LOG_INFO(m_LogOutput,
+ "Skipping chunk {} due to cache reuse",
+ m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]);
+ }
+ continue;
}
- else
+
+ bool NeedsCopy = true;
+ if (RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex].compare_exchange_strong(NeedsCopy, false))
{
- bool UsingCachedBlock = false;
- if (auto It = CachedBlocksFound.find(BlockDescription.BlockHash); It != CachedBlocksFound.end())
+ uint64_t WriteCount = GetChunkWriteCount(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex);
+ if (WriteCount == 0)
{
- TotalPartWriteCount++;
-
- std::filesystem::path BlockPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString();
- if (IsFile(BlockPath))
+ if (m_Options.IsVerbose)
{
- CachedChunkBlockIndexes.push_back(BlockIndex);
- UsingCachedBlock = true;
+ ZEN_OPERATION_LOG_INFO(m_LogOutput,
+ "Skipping chunk {} due to cache reuse",
+ m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]);
}
}
- if (!UsingCachedBlock)
+ else
{
- FetchBlockIndexes.push_back(BlockIndex);
+ NeededLooseChunkIndexes.push_back(LooseChunkIndex);
}
}
}
- AllBlockChunkIndexNeeded.emplace_back(std::move(BlockChunkIndexNeeded));
}
- BlobsExistsResult ExistsResult;
-
- if (m_Storage.BuildCacheStorage)
+ if (m_Storage.CacheStorage)
{
ZEN_TRACE_CPU("BlobCacheExistCheck");
Stopwatch Timer;
- tsl::robin_set<IoHash> BlobHashesSet;
+ std::vector<IoHash> BlobHashes;
+ BlobHashes.reserve(NeededLooseChunkIndexes.size() + FetchBlockIndexes.size());
- BlobHashesSet.reserve(LooseChunkHashWorks.size() + FetchBlockIndexes.size());
- for (LooseChunkHashWorkData& LooseChunkHashWork : LooseChunkHashWorks)
+ for (const uint32_t LooseChunkIndex : NeededLooseChunkIndexes)
{
- BlobHashesSet.insert(m_RemoteContent.ChunkedContent.ChunkHashes[LooseChunkHashWork.RemoteChunkIndex]);
+ BlobHashes.push_back(m_LooseChunkHashes[LooseChunkIndex]);
}
+
for (uint32_t BlockIndex : FetchBlockIndexes)
{
- const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
- BlobHashesSet.insert(BlockDescription.BlockHash);
+ BlobHashes.push_back(m_BlockDescriptions[BlockIndex].BlockHash);
}
- if (!BlobHashesSet.empty())
- {
- const std::vector<IoHash> BlobHashes(BlobHashesSet.begin(), BlobHashesSet.end());
- const std::vector<BuildStorageCache::BlobExistsResult> CacheExistsResult =
- m_Storage.BuildCacheStorage->BlobsExists(m_BuildId, BlobHashes);
+ const std::vector<BuildStorageCache::BlobExistsResult> CacheExistsResult =
+ m_Storage.CacheStorage->BlobsExists(m_BuildId, BlobHashes);
- if (CacheExistsResult.size() == BlobHashes.size())
+ if (CacheExistsResult.size() == BlobHashes.size())
+ {
+ ExistsResult.ExistingBlobs.reserve(CacheExistsResult.size());
+ for (size_t BlobIndex = 0; BlobIndex < BlobHashes.size(); BlobIndex++)
{
- ExistsResult.ExistingBlobs.reserve(CacheExistsResult.size());
- for (size_t BlobIndex = 0; BlobIndex < BlobHashes.size(); BlobIndex++)
+ if (CacheExistsResult[BlobIndex].HasBody)
{
- if (CacheExistsResult[BlobIndex].HasBody)
- {
- ExistsResult.ExistingBlobs.insert(BlobHashes[BlobIndex]);
- }
+ ExistsResult.ExistingBlobs.insert(BlobHashes[BlobIndex]);
}
}
- ExistsResult.ElapsedTimeMs = Timer.GetElapsedTimeMs();
- if (!ExistsResult.ExistingBlobs.empty() && !m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Remote cache : Found {} out of {} needed blobs in {}",
- ExistsResult.ExistingBlobs.size(),
- BlobHashes.size(),
- NiceTimeSpanMs(ExistsResult.ElapsedTimeMs));
- }
+ }
+ ExistsResult.ElapsedTimeMs = Timer.GetElapsedTimeMs();
+ if (!ExistsResult.ExistingBlobs.empty() && !m_Options.IsQuiet)
+ {
+ ZEN_OPERATION_LOG_INFO(m_LogOutput,
+ "Remote cache : Found {} out of {} needed blobs in {}",
+ ExistsResult.ExistingBlobs.size(),
+ BlobHashes.size(),
+ NiceTimeSpanMs(ExistsResult.ElapsedTimeMs));
}
}
- std::vector<BlockRangeDescriptor> BlockRangeWorks;
- std::vector<uint32_t> FullBlockWorks;
+ std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> BlockPartialDownloadModes;
+
+ if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Off)
{
- Stopwatch Timer;
+ BlockPartialDownloadModes.resize(m_BlockDescriptions.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::Off);
+ }
+ else
+ {
+ ChunkBlockAnalyser::EPartialBlockDownloadMode CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off;
+ ChunkBlockAnalyser::EPartialBlockDownloadMode CachePartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off;
- std::vector<uint32_t> PartialBlockIndexes;
+ switch (m_Options.PartialBlockRequestMode)
+ {
+ case EPartialBlockRequestMode::Off:
+ break;
+ case EPartialBlockRequestMode::ZenCacheOnly:
+ CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1
+ ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed
+ : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange;
+ CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off;
+ break;
+ case EPartialBlockRequestMode::Mixed:
+ CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1
+ ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed
+ : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange;
+ CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange;
+ break;
+ case EPartialBlockRequestMode::All:
+ CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1
+ ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed
+ : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange;
+ CloudPartialDownloadMode = m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest > 1
+ ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange
+ : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange;
+ break;
+ default:
+ ZEN_ASSERT(false);
+ break;
+ }
- for (uint32_t BlockIndex : FetchBlockIndexes)
+ BlockPartialDownloadModes.reserve(m_BlockDescriptions.size());
+ for (uint32_t BlockIndex = 0; BlockIndex < m_BlockDescriptions.size(); BlockIndex++)
{
- const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
+ const bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(m_BlockDescriptions[BlockIndex].BlockHash);
+ BlockPartialDownloadModes.push_back(BlockExistInCache ? CachePartialDownloadMode : CloudPartialDownloadMode);
+ }
+ }
- const std::vector<uint32_t> BlockChunkIndexNeeded = std::move(AllBlockChunkIndexNeeded[BlockIndex]);
- if (!BlockChunkIndexNeeded.empty())
- {
- bool WantsToDoPartialBlockDownload = BlockChunkIndexNeeded.size() < BlockDescription.ChunkRawHashes.size();
- bool CanDoPartialBlockDownload =
- (BlockDescription.HeaderSize > 0) &&
- (BlockDescription.ChunkCompressedLengths.size() == BlockDescription.ChunkRawHashes.size());
-
- bool AllowedToDoPartialRequest = false;
- bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash);
- switch (m_Options.PartialBlockRequestMode)
- {
- case EPartialBlockRequestMode::Off:
- break;
- case EPartialBlockRequestMode::ZenCacheOnly:
- AllowedToDoPartialRequest = BlockExistInCache;
- break;
- case EPartialBlockRequestMode::Mixed:
- case EPartialBlockRequestMode::All:
- AllowedToDoPartialRequest = true;
- break;
- default:
- ZEN_ASSERT(false);
- break;
- }
+ ZEN_ASSERT(BlockPartialDownloadModes.size() == m_BlockDescriptions.size());
- const uint32_t ChunkStartOffsetInBlock =
- gsl::narrow<uint32_t>(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize);
+ ChunkBlockAnalyser::BlockResult PartialBlocks =
+ BlockAnalyser.CalculatePartialBlockDownloads(NeededBlocks, BlockPartialDownloadModes);
- const uint64_t TotalBlockSize = std::accumulate(BlockDescription.ChunkCompressedLengths.begin(),
- BlockDescription.ChunkCompressedLengths.end(),
- std::uint64_t(ChunkStartOffsetInBlock));
+ struct LooseChunkHashWorkData
+ {
+ std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs;
+ uint32_t RemoteChunkIndex = (uint32_t)-1;
+ };
- if (AllowedToDoPartialRequest && WantsToDoPartialBlockDownload && CanDoPartialBlockDownload)
- {
- ZEN_TRACE_CPU("PartialBlockAnalysis");
-
- bool LimitToSingleRange =
- BlockExistInCache ? false : m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Mixed;
- uint64_t TotalWantedChunksSize = 0;
- std::optional<std::vector<BlockRangeDescriptor>> MaybeBlockRanges =
- CalculateBlockRanges(BlockIndex,
- BlockDescription,
- BlockChunkIndexNeeded,
- LimitToSingleRange,
- ChunkStartOffsetInBlock,
- TotalBlockSize,
- TotalWantedChunksSize);
- ZEN_ASSERT(TotalWantedChunksSize <= TotalBlockSize);
-
- if (MaybeBlockRanges.has_value())
- {
- const std::vector<BlockRangeDescriptor>& BlockRanges = MaybeBlockRanges.value();
- ZEN_ASSERT(!BlockRanges.empty());
- BlockRangeWorks.insert(BlockRangeWorks.end(), BlockRanges.begin(), BlockRanges.end());
- TotalRequestCount += BlockRanges.size();
- TotalPartWriteCount += BlockRanges.size();
-
- uint64_t RequestedSize = std::accumulate(
- BlockRanges.begin(),
- BlockRanges.end(),
- uint64_t(0),
- [](uint64_t Current, const BlockRangeDescriptor& Range) { return Current + Range.RangeLength; });
- PartialBlockIndexes.push_back(BlockIndex);
-
- if (RequestedSize > TotalWantedChunksSize)
- {
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(
- m_LogOutput,
- "Requesting {} chunks ({}) from block {} ({}) using {} requests (extra bytes {})",
- BlockChunkIndexNeeded.size(),
- NiceBytes(RequestedSize),
- BlockDescription.BlockHash,
- NiceBytes(TotalBlockSize),
- BlockRanges.size(),
- NiceBytes(RequestedSize - TotalWantedChunksSize));
- }
- }
- }
- else
- {
- FullBlockWorks.push_back(BlockIndex);
- TotalRequestCount++;
- TotalPartWriteCount++;
- }
- }
- else
- {
- FullBlockWorks.push_back(BlockIndex);
- TotalRequestCount++;
- TotalPartWriteCount++;
- }
- }
- }
+ TotalRequestCount += NeededLooseChunkIndexes.size();
+ TotalPartWriteCount += NeededLooseChunkIndexes.size();
+ TotalRequestCount += PartialBlocks.BlockRanges.size();
+ TotalPartWriteCount += PartialBlocks.BlockRanges.size();
+ TotalRequestCount += PartialBlocks.FullBlockIndexes.size();
+ TotalPartWriteCount += PartialBlocks.FullBlockIndexes.size();
- if (!PartialBlockIndexes.empty())
- {
- uint64_t TotalFullBlockRequestBytes = 0;
- for (uint32_t BlockIndex : FullBlockWorks)
- {
- const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
- uint32_t CurrentOffset =
- gsl::narrow<uint32_t>(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize);
+ std::vector<LooseChunkHashWorkData> LooseChunkHashWorks;
+ for (uint32_t LooseChunkIndex : NeededLooseChunkIndexes)
+ {
+ const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex];
+ auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash);
+ ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end());
+ const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second;
- TotalFullBlockRequestBytes += std::accumulate(BlockDescription.ChunkCompressedLengths.begin(),
- BlockDescription.ChunkCompressedLengths.end(),
- std::uint64_t(CurrentOffset));
- }
+ std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs =
+ GetRemainingChunkTargets(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex);
- uint64_t TotalPartialBlockBytes = 0;
- for (uint32_t BlockIndex : PartialBlockIndexes)
- {
- const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
- uint32_t CurrentOffset =
- gsl::narrow<uint32_t>(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize);
+ ZEN_ASSERT(!ChunkTargetPtrs.empty());
+ LooseChunkHashWorks.push_back(
+ LooseChunkHashWorkData{.ChunkTargetPtrs = ChunkTargetPtrs, .RemoteChunkIndex = RemoteChunkIndex});
+ }
- TotalPartialBlockBytes += std::accumulate(BlockDescription.ChunkCompressedLengths.begin(),
- BlockDescription.ChunkCompressedLengths.end(),
- std::uint64_t(CurrentOffset));
- }
+ ZEN_TRACE_CPU("WriteChunks");
- uint64_t NonPartialTotalBlockBytes = TotalFullBlockRequestBytes + TotalPartialBlockBytes;
+ m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::WriteChunks, (uint32_t)TaskSteps::StepCount);
- const uint64_t TotalPartialBlockRequestBytes =
- std::accumulate(BlockRangeWorks.begin(),
- BlockRangeWorks.end(),
- uint64_t(0),
- [](uint64_t Current, const BlockRangeDescriptor& Range) { return Current + Range.RangeLength; });
- uint64_t TotalExtraPartialBlocksRequests = BlockRangeWorks.size() - PartialBlockIndexes.size();
+ Stopwatch WriteTimer;
- uint64_t TotalSavedBlocksSize = TotalPartialBlockBytes - TotalPartialBlockRequestBytes;
- double SavedSizePercent = (TotalSavedBlocksSize * 100.0) / NonPartialTotalBlockBytes;
+ FilteredRate FilteredDownloadedBytesPerSecond;
+ FilteredRate FilteredWrittenBytesPerSecond;
- if (!m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Analysis of partial block requests saves download of {} out of {} ({:.1f}%) using {} extra "
- "requests. Completed in {}",
- NiceBytes(TotalSavedBlocksSize),
- NiceBytes(NonPartialTotalBlockBytes),
- SavedSizePercent,
- TotalExtraPartialBlocksRequests,
- NiceTimeSpanMs(ExistsResult.ElapsedTimeMs));
- }
- }
- }
+ std::unique_ptr<OperationLogOutput::ProgressBar> WriteProgressBarPtr(
+ m_LogOutput.CreateProgressBar(m_Options.PrimeCacheOnly ? "Downloading" : "Writing"));
+ OperationLogOutput::ProgressBar& WriteProgressBar(*WriteProgressBarPtr);
+ ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+
+ TotalPartWriteCount += CopyChunkDatas.size();
+ TotalPartWriteCount += ScavengedSequenceCopyOperations.size();
BufferedWriteFileCache WriteCache;
@@ -1472,13 +1345,23 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
});
}
- for (size_t BlockRangeIndex = 0; BlockRangeIndex < BlockRangeWorks.size(); BlockRangeIndex++)
+ for (size_t BlockRangeIndex = 0; BlockRangeIndex < PartialBlocks.BlockRanges.size();)
{
ZEN_ASSERT(!m_Options.PrimeCacheOnly);
if (m_AbortFlag)
{
break;
}
+
+ size_t RangeCount = 1;
+ size_t RangesLeft = PartialBlocks.BlockRanges.size() - BlockRangeIndex;
+ const ChunkBlockAnalyser::BlockRangeDescriptor& CurrentBlockRange = PartialBlocks.BlockRanges[BlockRangeIndex];
+ while (RangeCount < RangesLeft &&
+ CurrentBlockRange.BlockIndex == PartialBlocks.BlockRanges[BlockRangeIndex + RangeCount].BlockIndex)
+ {
+ RangeCount++;
+ }
+
Work.ScheduleWork(
m_NetworkPool,
[this,
@@ -1492,18 +1375,19 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
TotalPartWriteCount,
&FilteredWrittenBytesPerSecond,
&Work,
- &BlockRangeWorks,
- BlockRangeIndex](std::atomic<bool>&) {
+ &PartialBlocks,
+ BlockRangeStartIndex = BlockRangeIndex,
+ RangeCount = RangeCount](std::atomic<bool>&) {
if (!m_AbortFlag)
{
- ZEN_TRACE_CPU("Async_GetPartialBlock");
-
- const BlockRangeDescriptor& BlockRange = BlockRangeWorks[BlockRangeIndex];
+ ZEN_TRACE_CPU("Async_GetPartialBlockRanges");
FilteredDownloadedBytesPerSecond.Start();
DownloadPartialBlock(
- BlockRange,
+ PartialBlocks.BlockRanges,
+ BlockRangeStartIndex,
+ RangeCount,
ExistsResult,
[this,
&RemoteChunkIndexNeedsCopyFromSourceFlags,
@@ -1515,7 +1399,10 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
TotalPartWriteCount,
&FilteredDownloadedBytesPerSecond,
&FilteredWrittenBytesPerSecond,
- &BlockRange](IoBuffer&& InMemoryBuffer, const std::filesystem::path& OnDiskPath) {
+ &PartialBlocks](IoBuffer&& InMemoryBuffer,
+ const std::filesystem::path& OnDiskPath,
+ size_t BlockRangeStartIndex,
+ std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths) {
if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount)
{
FilteredDownloadedBytesPerSecond.Stop();
@@ -1533,14 +1420,18 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
&Work,
TotalPartWriteCount,
&FilteredWrittenBytesPerSecond,
- &BlockRange,
+ &PartialBlocks,
+ BlockRangeStartIndex,
BlockChunkPath = std::filesystem::path(OnDiskPath),
- BlockPartialBuffer = std::move(InMemoryBuffer)](std::atomic<bool>&) mutable {
+ BlockPartialBuffer = std::move(InMemoryBuffer),
+ OffsetAndLengths = std::vector<std::pair<uint64_t, uint64_t>>(OffsetAndLengths.begin(),
+ OffsetAndLengths.end())](
+ std::atomic<bool>&) mutable {
if (!m_AbortFlag)
{
ZEN_TRACE_CPU("Async_WritePartialBlock");
- const uint32_t BlockIndex = BlockRange.BlockIndex;
+ const uint32_t BlockIndex = PartialBlocks.BlockRanges[BlockRangeStartIndex].BlockIndex;
const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
@@ -1563,22 +1454,41 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
FilteredWrittenBytesPerSecond.Start();
- if (!WritePartialBlockChunksToCache(
- BlockDescription,
- SequenceIndexChunksLeftToWriteCounters,
- Work,
- CompositeBuffer(std::move(BlockPartialBuffer)),
- BlockRange.ChunkBlockIndexStart,
- BlockRange.ChunkBlockIndexStart + BlockRange.ChunkBlockIndexCount - 1,
- RemoteChunkIndexNeedsCopyFromSourceFlags,
- WriteCache))
+ size_t RangeCount = OffsetAndLengths.size();
+
+ for (size_t PartialRangeIndex = 0; PartialRangeIndex < RangeCount; PartialRangeIndex++)
{
- std::error_code DummyEc;
- RemoveFile(BlockChunkPath, DummyEc);
- throw std::runtime_error(
- fmt::format("Partial block {} is malformed", BlockDescription.BlockHash));
- }
+ const std::pair<uint64_t, uint64_t>& OffsetAndLength =
+ OffsetAndLengths[PartialRangeIndex];
+ IoBuffer BlockRangeBuffer(BlockPartialBuffer,
+ OffsetAndLength.first,
+ OffsetAndLength.second);
+
+ const ChunkBlockAnalyser::BlockRangeDescriptor& RangeDescriptor =
+ PartialBlocks.BlockRanges[BlockRangeStartIndex + PartialRangeIndex];
+
+ if (!WritePartialBlockChunksToCache(BlockDescription,
+ SequenceIndexChunksLeftToWriteCounters,
+ Work,
+ CompositeBuffer(std::move(BlockRangeBuffer)),
+ RangeDescriptor.ChunkBlockIndexStart,
+ RangeDescriptor.ChunkBlockIndexStart +
+ RangeDescriptor.ChunkBlockIndexCount - 1,
+ RemoteChunkIndexNeedsCopyFromSourceFlags,
+ WriteCache))
+ {
+ std::error_code DummyEc;
+ RemoveFile(BlockChunkPath, DummyEc);
+ throw std::runtime_error(
+ fmt::format("Partial block {} is malformed", BlockDescription.BlockHash));
+ }
+ WritePartsComplete++;
+ if (WritePartsComplete == TotalPartWriteCount)
+ {
+ FilteredWrittenBytesPerSecond.Stop();
+ }
+ }
std::error_code Ec = TryRemoveFile(BlockChunkPath);
if (Ec)
{
@@ -1588,12 +1498,6 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
Ec.value(),
Ec.message());
}
-
- WritePartsComplete++;
- if (WritePartsComplete == TotalPartWriteCount)
- {
- FilteredWrittenBytesPerSecond.Stop();
- }
}
},
OnDiskPath.empty() ? WorkerThreadPool::EMode::DisableBacklog
@@ -1602,9 +1506,10 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
});
}
});
+ BlockRangeIndex += RangeCount;
}
- for (uint32_t BlockIndex : FullBlockWorks)
+ for (uint32_t BlockIndex : PartialBlocks.FullBlockIndexes)
{
if (m_AbortFlag)
{
@@ -1641,20 +1546,20 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
IoBuffer BlockBuffer;
const bool ExistsInCache =
- m_Storage.BuildCacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash);
+ m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash);
if (ExistsInCache)
{
- BlockBuffer = m_Storage.BuildCacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash);
+ BlockBuffer = m_Storage.CacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash);
}
if (!BlockBuffer)
{
BlockBuffer = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash);
- if (BlockBuffer && m_Storage.BuildCacheStorage && m_Options.PopulateCache)
+ if (BlockBuffer && m_Storage.CacheStorage && m_Options.PopulateCache)
{
- m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId,
- BlockDescription.BlockHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(SharedBuffer(BlockBuffer)));
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ BlockDescription.BlockHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(SharedBuffer(BlockBuffer)));
}
}
if (!BlockBuffer)
@@ -3217,10 +3122,10 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde
const IoHash& ChunkHash = m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex];
// FilteredDownloadedBytesPerSecond.Start();
IoBuffer BuildBlob;
- const bool ExistsInCache = m_Storage.BuildCacheStorage && ExistsResult.ExistingBlobs.contains(ChunkHash);
+ const bool ExistsInCache = m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(ChunkHash);
if (ExistsInCache)
{
- BuildBlob = m_Storage.BuildCacheStorage->GetBuildBlob(m_BuildId, ChunkHash);
+ BuildBlob = m_Storage.CacheStorage->GetBuildBlob(m_BuildId, ChunkHash);
}
if (BuildBlob)
{
@@ -3248,12 +3153,12 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde
m_DownloadStats.DownloadedChunkCount++;
m_DownloadStats.RequestsCompleteCount++;
- if (Payload && m_Storage.BuildCacheStorage && m_Options.PopulateCache)
+ if (Payload && m_Storage.CacheStorage && m_Options.PopulateCache)
{
- m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId,
- ChunkHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(SharedBuffer(Payload)));
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ ChunkHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(SharedBuffer(Payload)));
}
OnDownloaded(std::move(Payload));
@@ -3262,12 +3167,12 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde
else
{
BuildBlob = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, ChunkHash);
- if (BuildBlob && m_Storage.BuildCacheStorage && m_Options.PopulateCache)
+ if (BuildBlob && m_Storage.CacheStorage && m_Options.PopulateCache)
{
- m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId,
- ChunkHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(SharedBuffer(BuildBlob)));
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ ChunkHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(SharedBuffer(BuildBlob)));
}
if (!BuildBlob)
{
@@ -3289,347 +3194,241 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde
}
}
-BuildsOperationUpdateFolder::BlockRangeDescriptor
-BuildsOperationUpdateFolder::MergeBlockRanges(std::span<const BlockRangeDescriptor> Ranges)
+void
+BuildsOperationUpdateFolder::DownloadPartialBlock(
+ std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRanges,
+ size_t BlockRangeStartIndex,
+ size_t BlockRangeCount,
+ const BlobsExistsResult& ExistsResult,
+ std::function<void(IoBuffer&& InMemoryBuffer,
+ const std::filesystem::path& OnDiskPath,
+ size_t BlockRangeStartIndex,
+ std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths)>&& OnDownloaded)
{
- ZEN_ASSERT(Ranges.size() > 1);
- const BlockRangeDescriptor& First = Ranges.front();
- const BlockRangeDescriptor& Last = Ranges.back();
-
- return BlockRangeDescriptor{.BlockIndex = First.BlockIndex,
- .RangeStart = First.RangeStart,
- .RangeLength = Last.RangeStart + Last.RangeLength - First.RangeStart,
- .ChunkBlockIndexStart = First.ChunkBlockIndexStart,
- .ChunkBlockIndexCount = Last.ChunkBlockIndexStart + Last.ChunkBlockIndexCount - First.ChunkBlockIndexStart};
-}
+ const uint32_t BlockIndex = BlockRanges[BlockRangeStartIndex].BlockIndex;
-std::optional<std::vector<BuildsOperationUpdateFolder::BlockRangeDescriptor>>
-BuildsOperationUpdateFolder::MakeOptionalBlockRangeVector(uint64_t TotalBlockSize, const BlockRangeDescriptor& Range)
-{
- if (Range.RangeLength == TotalBlockSize)
- {
- return {};
- }
- else
- {
- return std::vector<BlockRangeDescriptor>{Range};
- }
-};
+ const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
-const BuildsOperationUpdateFolder::BlockRangeLimit*
-BuildsOperationUpdateFolder::GetBlockRangeLimitForRange(std::span<const BlockRangeLimit> Limits,
- uint64_t TotalBlockSize,
- std::span<const BlockRangeDescriptor> Ranges)
-{
- if (Ranges.size() > 1)
- {
- const std::uint64_t WantedSize =
- std::accumulate(Ranges.begin(), Ranges.end(), uint64_t(0), [](uint64_t Current, const BlockRangeDescriptor& Range) {
- return Current + Range.RangeLength;
- });
+ auto ProcessDownload = [this](
+ const ChunkBlockDescription& BlockDescription,
+ IoBuffer&& BlockRangeBuffer,
+ size_t BlockRangeStartIndex,
+ std::span<const std::pair<uint64_t, uint64_t>> BlockOffsetAndLengths,
+ const std::function<void(IoBuffer && InMemoryBuffer,
+ const std::filesystem::path& OnDiskPath,
+ size_t BlockRangeStartIndex,
+ std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths)>& OnDownloaded) {
+ uint64_t BlockRangeBufferSize = BlockRangeBuffer.GetSize();
+ m_DownloadStats.DownloadedBlockCount++;
+ m_DownloadStats.DownloadedBlockByteCount += BlockRangeBufferSize;
+ m_DownloadStats.RequestsCompleteCount += BlockOffsetAndLengths.size();
- const double RangeRequestedPercent = (WantedSize * 100.0) / TotalBlockSize;
+ std::filesystem::path BlockChunkPath;
- for (const BlockRangeLimit& Limit : Limits)
+        // Check if the downloaded block is file based and we can move it directly without rewriting it
{
- if (RangeRequestedPercent >= Limit.SizePercent && Ranges.size() > Limit.MaxRangeCount)
+ IoBufferFileReference FileRef;
+ if (BlockRangeBuffer.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) &&
+ (FileRef.FileChunkSize == BlockRangeBufferSize))
{
- return &Limit;
- }
- }
- }
- return nullptr;
-};
+ ZEN_TRACE_CPU("MoveTempPartialBlock");
-std::vector<BuildsOperationUpdateFolder::BlockRangeDescriptor>
-BuildsOperationUpdateFolder::CollapseBlockRanges(const uint64_t AlwaysAcceptableGap, std::span<const BlockRangeDescriptor> BlockRanges)
-{
- ZEN_ASSERT(BlockRanges.size() > 1);
- std::vector<BlockRangeDescriptor> CollapsedBlockRanges;
+ std::error_code Ec;
+ std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, Ec);
+ if (!Ec)
+ {
+ BlockRangeBuffer.SetDeleteOnClose(false);
+ BlockRangeBuffer = {};
- auto BlockRangesIt = BlockRanges.begin();
- CollapsedBlockRanges.push_back(*BlockRangesIt++);
- for (; BlockRangesIt != BlockRanges.end(); BlockRangesIt++)
- {
- BlockRangeDescriptor& LastRange = CollapsedBlockRanges.back();
+ IoHashStream RangeId;
+ for (const std::pair<uint64_t, uint64_t>& Range : BlockOffsetAndLengths)
+ {
+ RangeId.Append(&Range.first, sizeof(uint64_t));
+ RangeId.Append(&Range.second, sizeof(uint64_t));
+ }
+
+ BlockChunkPath = m_TempBlockFolderPath / fmt::format("{}_{}", BlockDescription.BlockHash, RangeId.GetHash());
+ RenameFile(TempBlobPath, BlockChunkPath, Ec);
+ if (Ec)
+ {
+ BlockChunkPath = std::filesystem::path{};
- const uint64_t BothRangeSize = BlockRangesIt->RangeLength + LastRange.RangeLength;
+ // Re-open the temp file again
+ BasicFile OpenTemp(TempBlobPath, BasicFile::Mode::kDelete);
+ BlockRangeBuffer = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BlockRangeBufferSize, true);
+ BlockRangeBuffer.SetDeleteOnClose(true);
+ }
+ }
+ }
+ }
- const uint64_t Gap = BlockRangesIt->RangeStart - (LastRange.RangeStart + LastRange.RangeLength);
- if (Gap <= Max(BothRangeSize / 16, AlwaysAcceptableGap))
+ if (BlockChunkPath.empty() && (BlockRangeBufferSize > m_Options.MaximumInMemoryPayloadSize))
{
- LastRange.ChunkBlockIndexCount =
- (BlockRangesIt->ChunkBlockIndexStart + BlockRangesIt->ChunkBlockIndexCount) - LastRange.ChunkBlockIndexStart;
- LastRange.RangeLength = (BlockRangesIt->RangeStart + BlockRangesIt->RangeLength) - LastRange.RangeStart;
+ ZEN_TRACE_CPU("WriteTempPartialBlock");
+
+ IoHashStream RangeId;
+ for (const std::pair<uint64_t, uint64_t>& Range : BlockOffsetAndLengths)
+ {
+ RangeId.Append(&Range.first, sizeof(uint64_t));
+ RangeId.Append(&Range.second, sizeof(uint64_t));
+ }
+
+            // Could not be moved and rather large, let's store it on disk
+ BlockChunkPath = m_TempBlockFolderPath / fmt::format("{}_{}", BlockDescription.BlockHash, RangeId.GetHash());
+ TemporaryFile::SafeWriteFile(BlockChunkPath, BlockRangeBuffer);
+ BlockRangeBuffer = {};
}
- else
+ if (!m_AbortFlag)
{
- CollapsedBlockRanges.push_back(*BlockRangesIt);
+ OnDownloaded(std::move(BlockRangeBuffer), std::move(BlockChunkPath), BlockRangeStartIndex, BlockOffsetAndLengths);
}
- }
-
- return CollapsedBlockRanges;
-};
+ };
-uint64_t
-BuildsOperationUpdateFolder::CalculateNextGap(std::span<const BlockRangeDescriptor> BlockRanges)
-{
- ZEN_ASSERT(BlockRanges.size() > 1);
- uint64_t AcceptableGap = (uint64_t)-1;
- for (size_t RangeIndex = 0; RangeIndex < BlockRanges.size() - 1; RangeIndex++)
+ std::vector<std::pair<uint64_t, uint64_t>> Ranges;
+ Ranges.reserve(BlockRangeCount);
+ for (size_t BlockRangeIndex = BlockRangeStartIndex; BlockRangeIndex < BlockRangeStartIndex + BlockRangeCount; BlockRangeIndex++)
{
- const BlockRangeDescriptor& Range = BlockRanges[RangeIndex];
- const BlockRangeDescriptor& NextRange = BlockRanges[RangeIndex + 1];
-
- const uint64_t Gap = NextRange.RangeStart - (Range.RangeStart + Range.RangeLength);
- AcceptableGap = Min(Gap, AcceptableGap);
+ const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange = BlockRanges[BlockRangeIndex];
+ Ranges.push_back(std::make_pair(BlockRange.RangeStart, BlockRange.RangeLength));
}
- AcceptableGap = RoundUp(AcceptableGap, 16u * 1024u);
- return AcceptableGap;
-};
-std::optional<std::vector<BuildsOperationUpdateFolder::BlockRangeDescriptor>>
-BuildsOperationUpdateFolder::CalculateBlockRanges(uint32_t BlockIndex,
- const ChunkBlockDescription& BlockDescription,
- std::span<const uint32_t> BlockChunkIndexNeeded,
- bool LimitToSingleRange,
- const uint64_t ChunkStartOffsetInBlock,
- const uint64_t TotalBlockSize,
- uint64_t& OutTotalWantedChunksSize)
-{
- ZEN_TRACE_CPU("CalculateBlockRanges");
+ const bool ExistsInCache = m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash);
- std::vector<BlockRangeDescriptor> BlockRanges;
+ size_t SubBlockRangeCount = BlockRangeCount;
+ size_t SubRangeCountComplete = 0;
+ std::span<const std::pair<uint64_t, uint64_t>> RangesSpan(Ranges);
+ while (SubRangeCountComplete < SubBlockRangeCount)
{
- uint64_t CurrentOffset = ChunkStartOffsetInBlock;
- uint32_t ChunkBlockIndex = 0;
- uint32_t NeedBlockChunkIndexOffset = 0;
- BlockRangeDescriptor NextRange{.BlockIndex = BlockIndex};
- while (NeedBlockChunkIndexOffset < BlockChunkIndexNeeded.size() && ChunkBlockIndex < BlockDescription.ChunkRawHashes.size())
+ if (m_AbortFlag)
+ {
+ break;
+ }
+
+ // First try to get subrange from cache.
+ // If not successful, try to get the ranges from the build store and adapt SubRangeCount...
+
+ size_t SubRangeStartIndex = BlockRangeStartIndex + SubRangeCountComplete;
+ if (ExistsInCache)
{
- const uint32_t ChunkCompressedLength = BlockDescription.ChunkCompressedLengths[ChunkBlockIndex];
- if (ChunkBlockIndex < BlockChunkIndexNeeded[NeedBlockChunkIndexOffset])
+ size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, m_Storage.CacheHost.Caps.MaxRangeCountPerRequest);
+
+ if (SubRangeCount == 1)
{
- if (NextRange.RangeLength > 0)
+ // Legacy single-range path, prefer that for max compatibility
+
+ const std::pair<uint64_t, uint64_t> SubRange = RangesSpan[SubRangeCountComplete];
+ IoBuffer PayloadBuffer =
+ m_Storage.CacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash, SubRange.first, SubRange.second);
+ if (m_AbortFlag)
{
- BlockRanges.push_back(NextRange);
- NextRange = {.BlockIndex = BlockIndex};
+ break;
}
- ChunkBlockIndex++;
- CurrentOffset += ChunkCompressedLength;
- }
- else if (ChunkBlockIndex == BlockChunkIndexNeeded[NeedBlockChunkIndexOffset])
- {
- if (NextRange.RangeLength == 0)
+ if (PayloadBuffer)
{
- NextRange.RangeStart = CurrentOffset;
- NextRange.ChunkBlockIndexStart = ChunkBlockIndex;
+ ProcessDownload(BlockDescription,
+ std::move(PayloadBuffer),
+ SubRangeStartIndex,
+ std::vector<std::pair<uint64_t, uint64_t>>{std::make_pair(0u, SubRange.second)},
+ OnDownloaded);
+ SubRangeCountComplete += SubRangeCount;
+ continue;
}
- NextRange.RangeLength += ChunkCompressedLength;
- NextRange.ChunkBlockIndexCount++;
- ChunkBlockIndex++;
- CurrentOffset += ChunkCompressedLength;
- NeedBlockChunkIndexOffset++;
}
else
{
- ZEN_ASSERT(false);
- }
- }
- if (NextRange.RangeLength > 0)
- {
- BlockRanges.push_back(NextRange);
- }
- }
- ZEN_ASSERT(!BlockRanges.empty());
-
- OutTotalWantedChunksSize =
- std::accumulate(BlockRanges.begin(), BlockRanges.end(), uint64_t(0), [](uint64_t Current, const BlockRangeDescriptor& Range) {
- return Current + Range.RangeLength;
- });
+ auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount);
- double RangeWantedPercent = (OutTotalWantedChunksSize * 100.0) / TotalBlockSize;
-
- if (BlockRanges.size() == 1)
- {
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Range request of {} ({:.2f}%) using single range from block {} ({}) as is",
- NiceBytes(OutTotalWantedChunksSize),
- RangeWantedPercent,
- BlockDescription.BlockHash,
- NiceBytes(TotalBlockSize));
+ BuildStorageCache::BuildBlobRanges RangeBuffers =
+ m_Storage.CacheStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges);
+ if (m_AbortFlag)
+ {
+ break;
+ }
+ if (RangeBuffers.PayloadBuffer)
+ {
+ if (RangeBuffers.Ranges.empty())
+ {
+ SubRangeCount = Ranges.size() - SubRangeCountComplete;
+ ProcessDownload(BlockDescription,
+ std::move(RangeBuffers.PayloadBuffer),
+ SubRangeStartIndex,
+ RangesSpan.subspan(SubRangeCountComplete, SubRangeCount),
+ OnDownloaded);
+ SubRangeCountComplete += SubRangeCount;
+ continue;
+ }
+ else if (RangeBuffers.Ranges.size() == SubRangeCount)
+ {
+ ProcessDownload(BlockDescription,
+ std::move(RangeBuffers.PayloadBuffer),
+ SubRangeStartIndex,
+ RangeBuffers.Ranges,
+ OnDownloaded);
+ SubRangeCountComplete += SubRangeCount;
+ continue;
+ }
+ }
+ }
}
- return BlockRanges;
- }
- if (LimitToSingleRange)
- {
- const BlockRangeDescriptor MergedRange = MergeBlockRanges(BlockRanges);
- if (m_Options.IsVerbose)
- {
- const double RangeRequestedPercent = (MergedRange.RangeLength * 100.0) / TotalBlockSize;
- const double WastedPercent = ((MergedRange.RangeLength - OutTotalWantedChunksSize) * 100.0) / MergedRange.RangeLength;
+ size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest);
- ZEN_OPERATION_LOG_INFO(
- m_LogOutput,
- "Range request of {} ({:.2f}%) using {} ranges from block {} ({}) limited to single block range {} ({:.2f}%) wasting "
- "{:.2f}% ({})",
- NiceBytes(OutTotalWantedChunksSize),
- RangeWantedPercent,
- BlockRanges.size(),
- BlockDescription.BlockHash,
- NiceBytes(TotalBlockSize),
- NiceBytes(MergedRange.RangeLength),
- RangeRequestedPercent,
- WastedPercent,
- NiceBytes(MergedRange.RangeLength - OutTotalWantedChunksSize));
- }
- return MakeOptionalBlockRangeVector(TotalBlockSize, MergedRange);
- }
+ auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount);
- if (RangeWantedPercent > FullBlockRangePercentLimit)
- {
- const BlockRangeDescriptor MergedRange = MergeBlockRanges(BlockRanges);
- if (m_Options.IsVerbose)
+ BuildStorageBase::BuildBlobRanges RangeBuffers =
+ m_Storage.BuildStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges);
+ if (m_AbortFlag)
{
- const double RangeRequestedPercent = (MergedRange.RangeLength * 100.0) / TotalBlockSize;
- const double WastedPercent = ((MergedRange.RangeLength - OutTotalWantedChunksSize) * 100.0) / MergedRange.RangeLength;
-
- ZEN_OPERATION_LOG_INFO(
- m_LogOutput,
- "Range request of {} ({:.2f}%) using {} ranges from block {} ({}) exceeds {}%. Merged to single block range {} "
- "({:.2f}%) wasting {:.2f}% ({})",
- NiceBytes(OutTotalWantedChunksSize),
- RangeWantedPercent,
- BlockRanges.size(),
- BlockDescription.BlockHash,
- NiceBytes(TotalBlockSize),
- FullBlockRangePercentLimit,
- NiceBytes(MergedRange.RangeLength),
- RangeRequestedPercent,
- WastedPercent,
- NiceBytes(MergedRange.RangeLength - OutTotalWantedChunksSize));
+ break;
}
- return MakeOptionalBlockRangeVector(TotalBlockSize, MergedRange);
- }
-
- std::vector<BlockRangeDescriptor> CollapsedBlockRanges = CollapseBlockRanges(16u * 1024u, BlockRanges);
- while (GetBlockRangeLimitForRange(ForceMergeLimits, TotalBlockSize, CollapsedBlockRanges))
- {
- CollapsedBlockRanges = CollapseBlockRanges(CalculateNextGap(CollapsedBlockRanges), CollapsedBlockRanges);
- }
-
- const std::uint64_t WantedCollapsedSize =
- std::accumulate(CollapsedBlockRanges.begin(),
- CollapsedBlockRanges.end(),
- uint64_t(0),
- [](uint64_t Current, const BlockRangeDescriptor& Range) { return Current + Range.RangeLength; });
-
- const double CollapsedRangeRequestedPercent = (WantedCollapsedSize * 100.0) / TotalBlockSize;
-
- if (m_Options.IsVerbose)
- {
- const double WastedPercent = ((WantedCollapsedSize - OutTotalWantedChunksSize) * 100.0) / WantedCollapsedSize;
-
- ZEN_OPERATION_LOG_INFO(
- m_LogOutput,
- "Range request of {} ({:.2f}%) using {} ranges from block {} ({}) collapsed to {} {:.2f}% using {} ranges wasting {:.2f}% "
- "({})",
- NiceBytes(OutTotalWantedChunksSize),
- RangeWantedPercent,
- BlockRanges.size(),
- BlockDescription.BlockHash,
- NiceBytes(TotalBlockSize),
- NiceBytes(WantedCollapsedSize),
- CollapsedRangeRequestedPercent,
- CollapsedBlockRanges.size(),
- WastedPercent,
- NiceBytes(WantedCollapsedSize - OutTotalWantedChunksSize));
- }
- return CollapsedBlockRanges;
-}
-
-void
-BuildsOperationUpdateFolder::DownloadPartialBlock(
- const BlockRangeDescriptor BlockRange,
- const BlobsExistsResult& ExistsResult,
- std::function<void(IoBuffer&& InMemoryBuffer, const std::filesystem::path& OnDiskPath)>&& OnDownloaded)
-{
- const uint32_t BlockIndex = BlockRange.BlockIndex;
-
- const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
-
- IoBuffer BlockBuffer;
- if (m_Storage.BuildCacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash))
- {
- BlockBuffer =
- m_Storage.BuildCacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash, BlockRange.RangeStart, BlockRange.RangeLength);
- }
- if (!BlockBuffer)
- {
- BlockBuffer =
- m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash, BlockRange.RangeStart, BlockRange.RangeLength);
- }
- if (!BlockBuffer)
- {
- throw std::runtime_error(fmt::format("Block {} is missing when fetching range {} -> {}",
- BlockDescription.BlockHash,
- BlockRange.RangeStart,
- BlockRange.RangeStart + BlockRange.RangeLength));
- }
- if (!m_AbortFlag)
- {
- uint64_t BlockSize = BlockBuffer.GetSize();
- m_DownloadStats.DownloadedBlockCount++;
- m_DownloadStats.DownloadedBlockByteCount += BlockSize;
- m_DownloadStats.RequestsCompleteCount++;
-
- std::filesystem::path BlockChunkPath;
-
- // Check if the dowloaded block is file based and we can move it directly without rewriting it
+ if (RangeBuffers.PayloadBuffer)
{
- IoBufferFileReference FileRef;
- if (BlockBuffer.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) && (FileRef.FileChunkSize == BlockSize))
+ if (RangeBuffers.Ranges.empty())
{
- ZEN_TRACE_CPU("MoveTempPartialBlock");
+ // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3
+ // Upload to cache (if enabled) and use the whole payload for the remaining ranges
- std::error_code Ec;
- std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, Ec);
- if (!Ec)
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
{
- BlockBuffer.SetDeleteOnClose(false);
- BlockBuffer = {};
- BlockChunkPath = m_TempBlockFolderPath /
- fmt::format("{}_{:x}_{:x}", BlockDescription.BlockHash, BlockRange.RangeStart, BlockRange.RangeLength);
- RenameFile(TempBlobPath, BlockChunkPath, Ec);
- if (Ec)
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ BlockDescription.BlockHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(std::vector<IoBuffer>{RangeBuffers.PayloadBuffer}));
+ if (m_AbortFlag)
{
- BlockChunkPath = std::filesystem::path{};
-
- // Re-open the temp file again
- BasicFile OpenTemp(TempBlobPath, BasicFile::Mode::kDelete);
- BlockBuffer = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BlockSize, true);
- BlockBuffer.SetDeleteOnClose(true);
+ break;
}
}
- }
- }
- if (BlockChunkPath.empty() && (BlockSize > m_Options.MaximumInMemoryPayloadSize))
- {
- ZEN_TRACE_CPU("WriteTempPartialBlock");
- // Could not be moved and rather large, lets store it on disk
- BlockChunkPath = m_TempBlockFolderPath /
- fmt::format("{}_{:x}_{:x}", BlockDescription.BlockHash, BlockRange.RangeStart, BlockRange.RangeLength);
- TemporaryFile::SafeWriteFile(BlockChunkPath, BlockBuffer);
- BlockBuffer = {};
+ SubRangeCount = Ranges.size() - SubRangeCountComplete;
+ ProcessDownload(BlockDescription,
+ std::move(RangeBuffers.PayloadBuffer),
+ SubRangeStartIndex,
+ RangesSpan.subspan(SubRangeCountComplete, SubRangeCount),
+ OnDownloaded);
+ }
+ else
+ {
+ if (RangeBuffers.Ranges.size() != SubRanges.size())
+ {
+ throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges",
+ SubRanges.size(),
+ BlockDescription.BlockHash,
+ RangeBuffers.Ranges.size()));
+ }
+ ProcessDownload(BlockDescription,
+ std::move(RangeBuffers.PayloadBuffer),
+ SubRangeStartIndex,
+ RangeBuffers.Ranges,
+ OnDownloaded);
+ }
}
- if (!m_AbortFlag)
+ else
{
- OnDownloaded(std::move(BlockBuffer), std::move(BlockChunkPath));
+ throw std::runtime_error(fmt::format("Block {} is missing when fetching {} ranges", BlockDescription.BlockHash, SubRangeCount));
}
+
+ SubRangeCountComplete += SubRangeCount;
}
}
@@ -4083,7 +3882,8 @@ BuildsOperationUpdateFolder::WriteSequenceChunkToCache(BufferedWriteFileCache::L
}
bool
-BuildsOperationUpdateFolder::GetBlockWriteOps(std::span<const IoHash> ChunkRawHashes,
+BuildsOperationUpdateFolder::GetBlockWriteOps(const IoHash& BlockRawHash,
+ std::span<const IoHash> ChunkRawHashes,
std::span<const uint32_t> ChunkCompressedLengths,
std::span<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters,
std::span<std::atomic<bool>> RemoteChunkIndexNeedsCopyFromSourceFlags,
@@ -4115,9 +3915,34 @@ BuildsOperationUpdateFolder::GetBlockWriteOps(std::span<const IoHash> ChunkR
uint64_t VerifyChunkSize;
CompressedBuffer CompressedChunk =
CompressedBuffer::FromCompressed(SharedBuffer::MakeView(ChunkMemoryView), VerifyChunkHash, VerifyChunkSize);
- ZEN_ASSERT(CompressedChunk);
- ZEN_ASSERT(VerifyChunkHash == ChunkHash);
- ZEN_ASSERT(VerifyChunkSize == m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex]);
+ if (!CompressedChunk)
+ {
+ throw std::runtime_error(fmt::format("Chunk {} at {}, size {} in block {} is not a valid compressed buffer",
+ ChunkHash,
+ OffsetInBlock,
+ ChunkCompressedSize,
+ BlockRawHash));
+ }
+ if (VerifyChunkHash != ChunkHash)
+ {
+ throw std::runtime_error(fmt::format("Chunk {} at {}, size {} in block {} has a mismatching content hash {}",
+ ChunkHash,
+ OffsetInBlock,
+ ChunkCompressedSize,
+ BlockRawHash,
+ VerifyChunkHash));
+ }
+ if (VerifyChunkSize != m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex])
+ {
+ throw std::runtime_error(
+ fmt::format("Chunk {} at {}, size {} in block {} has a mismatching raw size {}, expected {}",
+ ChunkHash,
+ OffsetInBlock,
+ ChunkCompressedSize,
+ BlockRawHash,
+ VerifyChunkSize,
+ m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex]));
+ }
OodleCompressor ChunkCompressor;
OodleCompressionLevel ChunkCompressionLevel;
@@ -4138,7 +3963,18 @@ BuildsOperationUpdateFolder::GetBlockWriteOps(std::span<const IoHash> ChunkR
{
Decompressed = CompressedChunk.Decompress().AsIoBuffer();
}
- ZEN_ASSERT(Decompressed.GetSize() == m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex]);
+
+ if (Decompressed.GetSize() != m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex])
+ {
+ throw std::runtime_error(fmt::format("Chunk {} at {}, size {} in block {} decompressed to size {}, expected {}",
+ ChunkHash,
+ OffsetInBlock,
+ ChunkCompressedSize,
+ BlockRawHash,
+ Decompressed.GetSize(),
+ m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex]));
+ }
+
ZEN_ASSERT_SLOW(ChunkHash == IoHash::HashBuffer(Decompressed));
for (const ChunkedContentLookup::ChunkSequenceLocation* Target : ChunkTargetPtrs)
{
@@ -4237,7 +4073,8 @@ BuildsOperationUpdateFolder::WriteChunksBlockToCache(const ChunkBlockDescription
const std::vector<uint32_t> ChunkCompressedLengths =
ReadChunkBlockHeader(BlockView.Mid(CompressedBuffer::GetHeaderSizeForNoneEncoder()), HeaderSize);
- if (GetBlockWriteOps(BlockDescription.ChunkRawHashes,
+ if (GetBlockWriteOps(BlockDescription.BlockHash,
+ BlockDescription.ChunkRawHashes,
ChunkCompressedLengths,
SequenceIndexChunksLeftToWriteCounters,
RemoteChunkIndexNeedsCopyFromSourceFlags,
@@ -4252,7 +4089,8 @@ BuildsOperationUpdateFolder::WriteChunksBlockToCache(const ChunkBlockDescription
return false;
}
- if (GetBlockWriteOps(BlockDescription.ChunkRawHashes,
+ if (GetBlockWriteOps(BlockDescription.BlockHash,
+ BlockDescription.ChunkRawHashes,
BlockDescription.ChunkCompressedLengths,
SequenceIndexChunksLeftToWriteCounters,
RemoteChunkIndexNeedsCopyFromSourceFlags,
@@ -4283,7 +4121,8 @@ BuildsOperationUpdateFolder::WritePartialBlockChunksToCache(const ChunkBlockDesc
const MemoryView BlockView = BlockMemoryBuffer.GetView();
BlockWriteOps Ops;
- if (GetBlockWriteOps(BlockDescription.ChunkRawHashes,
+ if (GetBlockWriteOps(BlockDescription.BlockHash,
+ BlockDescription.ChunkRawHashes,
BlockDescription.ChunkCompressedLengths,
SequenceIndexChunksLeftToWriteCounters,
RemoteChunkIndexNeedsCopyFromSourceFlags,
@@ -5156,12 +4995,12 @@ BuildsOperationUploadFolder::GenerateBuildBlocks(const ChunkedFolderContent&
const IoHash& BlockHash = OutBlocks.BlockDescriptions[BlockIndex].BlockHash;
const uint64_t CompressedBlockSize = Payload.GetCompressedSize();
- if (m_Storage.BuildCacheStorage && m_Options.PopulateCache)
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
{
- m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId,
- BlockHash,
- ZenContentType::kCompressedBinary,
- Payload.GetCompressed());
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ BlockHash,
+ ZenContentType::kCompressedBinary,
+ Payload.GetCompressed());
}
m_Storage.BuildStorage->PutBuildBlob(m_BuildId,
@@ -5179,11 +5018,11 @@ BuildsOperationUploadFolder::GenerateBuildBlocks(const ChunkedFolderContent&
OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size());
}
- if (m_Storage.BuildCacheStorage && m_Options.PopulateCache)
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
{
- m_Storage.BuildCacheStorage->PutBlobMetadatas(m_BuildId,
- std::vector<IoHash>({BlockHash}),
- std::vector<CbObject>({BlockMetaData}));
+ m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId,
+ std::vector<IoHash>({BlockHash}),
+ std::vector<CbObject>({BlockMetaData}));
}
bool MetadataSucceeded =
@@ -5334,6 +5173,13 @@ BuildsOperationUploadFolder::FetchChunk(const ChunkedFolderContent& Content,
ZEN_ASSERT(!ChunkLocations.empty());
CompositeBuffer Chunk =
OpenFileCache.GetRange(ChunkLocations[0].SequenceIndex, ChunkLocations[0].Offset, Content.ChunkedContent.ChunkRawSizes[ChunkIndex]);
+ if (!Chunk)
+ {
+ throw std::runtime_error(fmt::format("Unable to read chunk at {}, size {} from '{}'",
+ ChunkLocations[0].Offset,
+ Content.ChunkedContent.ChunkRawSizes[ChunkIndex],
+ Content.Paths[Lookup.SequenceIndexFirstPathIndex[ChunkLocations[0].SequenceIndex]]));
+ }
ZEN_ASSERT_SLOW(IoHash::HashBuffer(Chunk) == ChunkHash);
return Chunk;
};
@@ -5362,10 +5208,7 @@ BuildsOperationUploadFolder::GenerateBlock(const ChunkedFolderContent& Content,
Content.ChunkedContent.ChunkHashes[ChunkIndex],
[this, &Content, &Lookup, &OpenFileCache, ChunkIndex](const IoHash& ChunkHash) -> std::pair<uint64_t, CompressedBuffer> {
CompositeBuffer Chunk = FetchChunk(Content, Lookup, ChunkHash, OpenFileCache);
- if (!Chunk)
- {
- ZEN_ASSERT(false);
- }
+ ZEN_ASSERT(Chunk);
uint64_t RawSize = Chunk.GetSize();
const bool ShouldCompressChunk = RawSize >= m_Options.MinimumSizeForCompressInBlock &&
@@ -6023,11 +5866,11 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController
{
const CbObject BlockMetaData =
BuildChunkBlockDescription(NewBlocks.BlockDescriptions[BlockIndex], NewBlocks.BlockMetaDatas[BlockIndex]);
- if (m_Storage.BuildCacheStorage && m_Options.PopulateCache)
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
{
- m_Storage.BuildCacheStorage->PutBlobMetadatas(m_BuildId,
- std::vector<IoHash>({BlockHash}),
- std::vector<CbObject>({BlockMetaData}));
+ m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId,
+ std::vector<IoHash>({BlockHash}),
+ std::vector<CbObject>({BlockMetaData}));
}
bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData);
if (MetadataSucceeded)
@@ -6221,9 +6064,9 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co
const CbObject BlockMetaData =
BuildChunkBlockDescription(NewBlocks.BlockDescriptions[BlockIndex], NewBlocks.BlockMetaDatas[BlockIndex]);
- if (m_Storage.BuildCacheStorage && m_Options.PopulateCache)
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
{
- m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload);
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload);
}
m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload);
if (m_Options.IsVerbose)
@@ -6237,11 +6080,11 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co
UploadedBlockSize += PayloadSize;
TempUploadStats.BlocksBytes += PayloadSize;
- if (m_Storage.BuildCacheStorage && m_Options.PopulateCache)
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
{
- m_Storage.BuildCacheStorage->PutBlobMetadatas(m_BuildId,
- std::vector<IoHash>({BlockHash}),
- std::vector<CbObject>({BlockMetaData}));
+ m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId,
+ std::vector<IoHash>({BlockHash}),
+ std::vector<CbObject>({BlockMetaData}));
}
bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData);
if (MetadataSucceeded)
@@ -6304,9 +6147,9 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co
const uint64_t PayloadSize = Payload.GetSize();
- if (m_Storage.BuildCacheStorage && m_Options.PopulateCache)
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
{
- m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload);
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload);
}
if (PayloadSize >= LargeAttachmentSize)
@@ -7050,14 +6893,14 @@ BuildsOperationPrimeCache::Execute()
std::vector<IoHash> BlobsToDownload;
BlobsToDownload.reserve(BuildBlobs.size());
- if (m_Storage.BuildCacheStorage && !BuildBlobs.empty() && !m_Options.ForceUpload)
+ if (m_Storage.CacheStorage && !BuildBlobs.empty() && !m_Options.ForceUpload)
{
ZEN_TRACE_CPU("BlobCacheExistCheck");
Stopwatch Timer;
const std::vector<IoHash> BlobHashes(BuildBlobs.begin(), BuildBlobs.end());
const std::vector<BuildStorageCache::BlobExistsResult> CacheExistsResult =
- m_Storage.BuildCacheStorage->BlobsExists(m_BuildId, BlobHashes);
+ m_Storage.CacheStorage->BlobsExists(m_BuildId, BlobHashes);
if (CacheExistsResult.size() == BlobHashes.size())
{
@@ -7104,33 +6947,33 @@ BuildsOperationPrimeCache::Execute()
for (size_t BlobIndex = 0; BlobIndex < BlobCount; BlobIndex++)
{
- Work.ScheduleWork(
- m_NetworkPool,
- [this,
- &Work,
- &BlobsToDownload,
- BlobCount,
- &LooseChunkRawSizes,
- &CompletedDownloadCount,
- &FilteredDownloadedBytesPerSecond,
- &MultipartAttachmentCount,
- BlobIndex](std::atomic<bool>&) {
- if (!m_AbortFlag)
- {
- const IoHash& BlobHash = BlobsToDownload[BlobIndex];
+ Work.ScheduleWork(m_NetworkPool,
+ [this,
+ &Work,
+ &BlobsToDownload,
+ BlobCount,
+ &LooseChunkRawSizes,
+ &CompletedDownloadCount,
+ &FilteredDownloadedBytesPerSecond,
+ &MultipartAttachmentCount,
+ BlobIndex](std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ const IoHash& BlobHash = BlobsToDownload[BlobIndex];
- bool IsLargeBlob = false;
+ bool IsLargeBlob = false;
- if (auto It = LooseChunkRawSizes.find(BlobHash); It != LooseChunkRawSizes.end())
- {
- IsLargeBlob = It->second >= m_Options.LargeAttachmentSize;
- }
+ if (auto It = LooseChunkRawSizes.find(BlobHash); It != LooseChunkRawSizes.end())
+ {
+ IsLargeBlob = It->second >= m_Options.LargeAttachmentSize;
+ }
- FilteredDownloadedBytesPerSecond.Start();
+ FilteredDownloadedBytesPerSecond.Start();
- if (IsLargeBlob)
- {
- DownloadLargeBlob(*m_Storage.BuildStorage,
+ if (IsLargeBlob)
+ {
+ DownloadLargeBlob(
+ *m_Storage.BuildStorage,
m_TempPath,
m_BuildId,
BlobHash,
@@ -7146,12 +6989,12 @@ BuildsOperationPrimeCache::Execute()
if (!m_AbortFlag)
{
- if (Payload && m_Storage.BuildCacheStorage)
+ if (Payload && m_Storage.CacheStorage)
{
- m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId,
- BlobHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(SharedBuffer(Payload)));
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ BlobHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(SharedBuffer(Payload)));
}
}
CompletedDownloadCount++;
@@ -7160,32 +7003,32 @@ BuildsOperationPrimeCache::Execute()
FilteredDownloadedBytesPerSecond.Stop();
}
});
- }
- else
- {
- IoBuffer Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash);
- m_DownloadStats.DownloadedBlockCount++;
- m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize();
- m_DownloadStats.RequestsCompleteCount++;
+ }
+ else
+ {
+ IoBuffer Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash);
+ m_DownloadStats.DownloadedBlockCount++;
+ m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize();
+ m_DownloadStats.RequestsCompleteCount++;
- if (!m_AbortFlag)
- {
- if (Payload && m_Storage.BuildCacheStorage)
- {
- m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId,
- BlobHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(SharedBuffer(std::move(Payload))));
- }
- }
- CompletedDownloadCount++;
- if (CompletedDownloadCount == BlobCount)
- {
- FilteredDownloadedBytesPerSecond.Stop();
- }
- }
- }
- });
+ if (!m_AbortFlag)
+ {
+ if (Payload && m_Storage.CacheStorage)
+ {
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ BlobHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(SharedBuffer(std::move(Payload))));
+ }
+ }
+ CompletedDownloadCount++;
+ if (CompletedDownloadCount == BlobCount)
+ {
+ FilteredDownloadedBytesPerSecond.Stop();
+ }
+ }
+ }
+ });
}
Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
@@ -7197,10 +7040,10 @@ BuildsOperationPrimeCache::Execute()
std::string DownloadRateString = (CompletedDownloadCount == BlobCount)
? ""
: fmt::format(" {}bits/s", NiceNum(FilteredDownloadedBytesPerSecond.GetCurrent() * 8));
- std::string UploadDetails = m_Storage.BuildCacheStorage ? fmt::format(" {} ({}) uploaded.",
- m_StorageCacheStats.PutBlobCount.load(),
- NiceBytes(m_StorageCacheStats.PutBlobByteCount.load()))
- : "";
+ std::string UploadDetails = m_Storage.CacheStorage ? fmt::format(" {} ({}) uploaded.",
+ m_StorageCacheStats.PutBlobCount.load(),
+ NiceBytes(m_StorageCacheStats.PutBlobByteCount.load()))
+ : "";
std::string Details = fmt::format("{}/{} ({}{}) downloaded.{}",
CompletedDownloadCount.load(),
@@ -7225,13 +7068,13 @@ BuildsOperationPrimeCache::Execute()
return;
}
- if (m_Storage.BuildCacheStorage)
+ if (m_Storage.CacheStorage)
{
- m_Storage.BuildCacheStorage->Flush(m_LogOutput.GetProgressUpdateDelayMS(), [this](intptr_t Remaining) -> bool {
+ m_Storage.CacheStorage->Flush(m_LogOutput.GetProgressUpdateDelayMS(), [this](intptr_t Remaining) -> bool {
ZEN_UNUSED(Remaining);
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Waiting for {} blobs to finish upload to '{}'", Remaining, m_Storage.CacheName);
+ ZEN_OPERATION_LOG_INFO(m_LogOutput, "Waiting for {} blobs to finish upload to '{}'", Remaining, m_Storage.CacheHost.Name);
}
return !m_AbortFlag;
});
@@ -7431,16 +7274,31 @@ GetRemoteContent(OperationLogOutput& Output,
// TODO: GetBlockDescriptions for all BlockRawHashes in one go - check for local block descriptions when we cache them
{
+ if (!IsQuiet)
+ {
+ ZEN_OPERATION_LOG_INFO(Output, "Fetching metadata for {} blocks", BlockRawHashes.size());
+ }
+
+ Stopwatch GetBlockMetadataTimer;
+
bool AttemptFallback = false;
OutBlockDescriptions = GetBlockDescriptions(Output,
*Storage.BuildStorage,
- Storage.BuildCacheStorage.get(),
+ Storage.CacheStorage.get(),
BuildId,
- BuildPartId,
BlockRawHashes,
AttemptFallback,
IsQuiet,
IsVerbose);
+
+ if (!IsQuiet)
+ {
+ ZEN_OPERATION_LOG_INFO(Output,
+ "GetBlockMetadata for {} took {}. Found {} blocks",
+ BuildPartId,
+ NiceTimeSpanMs(GetBlockMetadataTimer.GetElapsedTimeMs()),
+ OutBlockDescriptions.size());
+ }
}
CalculateLocalChunkOrders(AbsoluteChunkOrders,
@@ -7989,6 +7847,8 @@ namespace buildstorageoperations_testutils {
} // namespace buildstorageoperations_testutils
+TEST_SUITE_BEGIN("remotestore.buildstorageoperations");
+
TEST_CASE("buildstorageoperations.upload.folder")
{
using namespace buildstorageoperations_testutils;
@@ -8176,106 +8036,270 @@ TEST_CASE("buildstorageoperations.memorychunkingcache")
TEST_CASE("buildstorageoperations.upload.multipart")
{
- using namespace buildstorageoperations_testutils;
+ // Disabled since it relies on authentication and specific block being present in cloud storage
+ if (false)
+ {
+ using namespace buildstorageoperations_testutils;
- FastRandom BaseRandom;
+ FastRandom BaseRandom;
- const size_t FileCount = 11;
+ const size_t FileCount = 11;
- const std::string Paths[FileCount] = {{"file_1"},
- {"file_2.exe"},
- {"file_3.txt"},
- {"dir_1/dir1_file_1.exe"},
- {"dir_1/dir1_file_2.pdb"},
- {"dir_1/dir1_file_3.txt"},
- {"dir_2/dir2_dir1/dir2_dir1_file_1.exe"},
- {"dir_2/dir2_dir1/dir2_dir1_file_2.pdb"},
- {"dir_2/dir2_dir1/dir2_dir1_file_3.dll"},
- {"dir_2/dir2_dir2/dir2_dir2_file_1.txt"},
- {"dir_2/dir2_dir2/dir2_dir2_file_2.json"}};
- const uint64_t Sizes[FileCount] =
- {6u * 1024u, 0, 798, 19u * 1024u, 7u * 1024u, 93, 31u * 1024u, 17u * 1024u, 13u * 1024u, 2u * 1024u, 3u * 1024u};
+ const std::string Paths[FileCount] = {{"file_1"},
+ {"file_2.exe"},
+ {"file_3.txt"},
+ {"dir_1/dir1_file_1.exe"},
+ {"dir_1/dir1_file_2.pdb"},
+ {"dir_1/dir1_file_3.txt"},
+ {"dir_2/dir2_dir1/dir2_dir1_file_1.exe"},
+ {"dir_2/dir2_dir1/dir2_dir1_file_2.pdb"},
+ {"dir_2/dir2_dir1/dir2_dir1_file_3.dll"},
+ {"dir_2/dir2_dir2/dir2_dir2_file_1.txt"},
+ {"dir_2/dir2_dir2/dir2_dir2_file_2.json"}};
+ const uint64_t Sizes[FileCount] =
+ {6u * 1024u, 0, 798, 19u * 1024u, 7u * 1024u, 93, 31u * 1024u, 17u * 1024u, 13u * 1024u, 2u * 1024u, 3u * 1024u};
- ScopedTemporaryDirectory SourceFolder;
- TestState State(SourceFolder.Path());
- State.Initialize();
- State.CreateSourceData("source", Paths, Sizes);
+ ScopedTemporaryDirectory SourceFolder;
+ TestState State(SourceFolder.Path());
+ State.Initialize();
+ State.CreateSourceData("source", Paths, Sizes);
- std::span<const std::string> ManifestFiles1(Paths);
- ManifestFiles1 = ManifestFiles1.subspan(0, FileCount / 2);
+ std::span<const std::string> ManifestFiles1(Paths);
+ ManifestFiles1 = ManifestFiles1.subspan(0, FileCount / 2);
- std::span<const uint64_t> ManifestSizes1(Sizes);
- ManifestSizes1 = ManifestSizes1.subspan(0, FileCount / 2);
+ std::span<const uint64_t> ManifestSizes1(Sizes);
+ ManifestSizes1 = ManifestSizes1.subspan(0, FileCount / 2);
- std::span<const std::string> ManifestFiles2(Paths);
- ManifestFiles2 = ManifestFiles2.subspan(FileCount / 2 - 1);
+ std::span<const std::string> ManifestFiles2(Paths);
+ ManifestFiles2 = ManifestFiles2.subspan(FileCount / 2 - 1);
- std::span<const uint64_t> ManifestSizes2(Sizes);
- ManifestSizes2 = ManifestSizes2.subspan(FileCount / 2 - 1);
+ std::span<const uint64_t> ManifestSizes2(Sizes);
+ ManifestSizes2 = ManifestSizes2.subspan(FileCount / 2 - 1);
- const Oid BuildPart1Id = Oid::NewOid();
- const std::string BuildPart1Name = "part1";
- const Oid BuildPart2Id = Oid::NewOid();
- const std::string BuildPart2Name = "part2";
- {
- CbObjectWriter Writer;
- Writer.BeginObject("parts"sv);
+ const Oid BuildPart1Id = Oid::NewOid();
+ const std::string BuildPart1Name = "part1";
+ const Oid BuildPart2Id = Oid::NewOid();
+ const std::string BuildPart2Name = "part2";
{
- Writer.BeginObject(BuildPart1Name);
+ CbObjectWriter Writer;
+ Writer.BeginObject("parts"sv);
{
- Writer.AddObjectId("partId"sv, BuildPart1Id);
- Writer.BeginArray("files"sv);
- for (const std::string& ManifestFile : ManifestFiles1)
+ Writer.BeginObject(BuildPart1Name);
{
- Writer.AddString(ManifestFile);
+ Writer.AddObjectId("partId"sv, BuildPart1Id);
+ Writer.BeginArray("files"sv);
+ for (const std::string& ManifestFile : ManifestFiles1)
+ {
+ Writer.AddString(ManifestFile);
+ }
+ Writer.EndArray(); // files
+ }
+ Writer.EndObject(); // part1
+
+ Writer.BeginObject(BuildPart2Name);
+ {
+ Writer.AddObjectId("partId"sv, BuildPart2Id);
+ Writer.BeginArray("files"sv);
+ for (const std::string& ManifestFile : ManifestFiles2)
+ {
+ Writer.AddString(ManifestFile);
+ }
+ Writer.EndArray(); // files
}
- Writer.EndArray(); // files
+ Writer.EndObject(); // part2
+ }
+ Writer.EndObject(); // parts
+
+ ExtendableStringBuilder<1024> Manifest;
+ CompactBinaryToJson(Writer.Save(), Manifest);
+ WriteFile(State.RootPath / "manifest.json", IoBuffer(IoBuffer::Wrap, Manifest.Data(), Manifest.Size()));
+ }
+
+ const Oid BuildId = Oid::NewOid();
+
+ auto Result = State.Upload(BuildId, {}, {}, "source", State.RootPath / "manifest.json");
+
+ CHECK_EQ(Result.size(), 2u);
+ CHECK_EQ(Result[0].first, BuildPart1Id);
+ CHECK_EQ(Result[0].second, BuildPart1Name);
+ CHECK_EQ(Result[1].first, BuildPart2Id);
+ CHECK_EQ(Result[1].second, BuildPart2Name);
+ State.ValidateUpload(BuildId, Result);
+
+ FolderContent DownloadContent = State.Download(BuildId, Oid::Zero, {}, "download", /* Append */ false);
+ State.ValidateDownload(Paths, Sizes, "source", "download", DownloadContent);
+
+ FolderContent Part1DownloadContent = State.Download(BuildId, BuildPart1Id, {}, "download_part1", /* Append */ false);
+ State.ValidateDownload(ManifestFiles1, ManifestSizes1, "source", "download_part1", Part1DownloadContent);
+
+ FolderContent Part2DownloadContent = State.Download(BuildId, Oid::Zero, BuildPart2Name, "download_part2", /* Append */ false);
+ State.ValidateDownload(ManifestFiles2, ManifestSizes2, "source", "download_part2", Part2DownloadContent);
+
+ (void)State.Download(BuildId, BuildPart1Id, BuildPart1Name, "download_part1+2", /* Append */ false);
+ FolderContent Part1And2DownloadContent = State.Download(BuildId, BuildPart2Id, {}, "download_part1+2", /* Append */ true);
+ State.ValidateDownload(Paths, Sizes, "source", "download_part1+2", Part1And2DownloadContent);
+ }
+}
+
+TEST_CASE("buildstorageoperations.partial.block.download" * doctest::skip(true))
+{
+ const std::string OidcExecutableName = "OidcToken" ZEN_EXE_SUFFIX_LITERAL;
+ std::filesystem::path OidcTokenExePath = (GetRunningExecutablePath().parent_path() / OidcExecutableName).make_preferred();
+
+ HttpClientSettings ClientSettings{
+ .LogCategory = "httpbuildsclient",
+ .AccessTokenProvider =
+ httpclientauth::CreateFromOidcTokenExecutable(OidcTokenExePath, "https://jupiter.devtools.epicgames.com", true, false, false),
+ .AssumeHttp2 = false,
+ .AllowResume = true,
+ .RetryCount = 0,
+ .Verbose = false};
+
+ HttpClient HttpClient("https://euc.jupiter.devtools.epicgames.com", ClientSettings);
+
+ const std::string_view Namespace = "fortnite.oplog";
+ const std::string_view Bucket = "fortnitegame.staged-build.fortnite-main.ps4-client";
+ const Oid BuildId = Oid::FromHexString("09a76ea92ad301d4724fafad");
+
+ {
+ HttpClient::Response Response = HttpClient.Get(fmt::format("/api/v2/builds/{}/{}/{}", Namespace, Bucket, BuildId),
+ HttpClient::Accept(ZenContentType::kCbObject));
+ CbValidateError ValidateResult = CbValidateError::None;
+ CbObject Object = ValidateAndReadCompactBinaryObject(IoBuffer(Response.ResponsePayload), ValidateResult);
+ REQUIRE(ValidateResult == CbValidateError::None);
+ }
+
+ std::vector<ChunkBlockDescription> BlockDescriptions;
+ {
+ CbObjectWriter Request;
+
+ Request.BeginArray("blocks"sv);
+ {
+ Request.AddHash(IoHash::FromHexString("7c353ed782675a5e8f968e61e51fc797ecdc2882"));
+ }
+ Request.EndArray();
+
+ IoBuffer Payload = Request.Save().GetBuffer().AsIoBuffer();
+ Payload.SetContentType(ZenContentType::kCbObject);
+
+ HttpClient::Response BlockDescriptionsResponse =
+ HttpClient.Post(fmt::format("/api/v2/builds/{}/{}/{}/blocks/getBlockMetadata", Namespace, Bucket, BuildId),
+ Payload,
+ HttpClient::Accept(ZenContentType::kCbObject));
+ REQUIRE(BlockDescriptionsResponse.IsSuccess());
+
+ CbValidateError ValidateResult = CbValidateError::None;
+ CbObject Object = ValidateAndReadCompactBinaryObject(IoBuffer(BlockDescriptionsResponse.ResponsePayload), ValidateResult);
+ REQUIRE(ValidateResult == CbValidateError::None);
+
+ {
+ CbArrayView BlocksArray = Object["blocks"sv].AsArrayView();
+ for (CbFieldView Block : BlocksArray)
+ {
+ ChunkBlockDescription Description = ParseChunkBlockDescription(Block.AsObjectView());
+ BlockDescriptions.emplace_back(std::move(Description));
}
- Writer.EndObject(); // part1
+ }
+ }
+
+ REQUIRE(!BlockDescriptions.empty());
- Writer.BeginObject(BuildPart2Name);
+ const IoHash BlockHash = BlockDescriptions.back().BlockHash;
+
+ const ChunkBlockDescription& BlockDescription = BlockDescriptions.front();
+ REQUIRE(!BlockDescription.ChunkRawHashes.empty());
+ REQUIRE(!BlockDescription.ChunkCompressedLengths.empty());
+
+ std::vector<std::pair<uint64_t, uint64_t>> ChunkOffsetAndSizes;
+ uint64_t Offset = gsl::narrow<uint32_t>(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize);
+
+ for (uint32_t ChunkCompressedSize : BlockDescription.ChunkCompressedLengths)
+ {
+ ChunkOffsetAndSizes.push_back(std::make_pair(Offset, ChunkCompressedSize));
+ Offset += ChunkCompressedSize;
+ }
+
+ ScopedTemporaryDirectory SourceFolder;
+
+ auto Validate = [&](std::span<const uint32_t> ChunkIndexesToFetch) {
+ std::vector<std::pair<uint64_t, uint64_t>> Ranges;
+ for (uint32_t ChunkIndex : ChunkIndexesToFetch)
+ {
+ Ranges.push_back(ChunkOffsetAndSizes[ChunkIndex]);
+ }
+
+ HttpClient::KeyValueMap Headers;
+ if (!Ranges.empty())
+ {
+ ExtendableStringBuilder<512> SB;
+ for (const std::pair<uint64_t, uint64_t>& R : Ranges)
{
- Writer.AddObjectId("partId"sv, BuildPart2Id);
- Writer.BeginArray("files"sv);
- for (const std::string& ManifestFile : ManifestFiles2)
+ if (SB.Size() > 0)
{
- Writer.AddString(ManifestFile);
+ SB << ", ";
}
- Writer.EndArray(); // files
+ SB << R.first << "-" << R.first + R.second - 1;
}
- Writer.EndObject(); // part2
+ Headers.Entries.insert({"Range", fmt::format("bytes={}", SB.ToView())});
}
- Writer.EndObject(); // parts
- ExtendableStringBuilder<1024> Manifest;
- CompactBinaryToJson(Writer.Save(), Manifest);
- WriteFile(State.RootPath / "manifest.json", IoBuffer(IoBuffer::Wrap, Manifest.Data(), Manifest.Size()));
- }
+ HttpClient::Response GetBlobRangesResponse = HttpClient.Download(
+ fmt::format("/api/v2/builds/{}/{}/{}/blobs/{}?supportsRedirect=false", Namespace, Bucket, BuildId, BlockHash),
+ SourceFolder.Path(),
+ Headers);
- const Oid BuildId = Oid::NewOid();
+ REQUIRE(GetBlobRangesResponse.IsSuccess());
+ [[maybe_unused]] MemoryView RangesMemoryView = GetBlobRangesResponse.ResponsePayload.GetView();
- auto Result = State.Upload(BuildId, {}, {}, "source", State.RootPath / "manifest.json");
+ std::vector<std::pair<uint64_t, uint64_t>> PayloadRanges = GetBlobRangesResponse.GetRanges(Ranges);
+ if (PayloadRanges.empty())
+ {
+ // We got the whole blob, use the ranges as is
+ PayloadRanges = Ranges;
+ }
- CHECK_EQ(Result.size(), 2u);
- CHECK_EQ(Result[0].first, BuildPart1Id);
- CHECK_EQ(Result[0].second, BuildPart1Name);
- CHECK_EQ(Result[1].first, BuildPart2Id);
- CHECK_EQ(Result[1].second, BuildPart2Name);
- State.ValidateUpload(BuildId, Result);
+ REQUIRE(PayloadRanges.size() == Ranges.size());
- FolderContent DownloadContent = State.Download(BuildId, Oid::Zero, {}, "download", /* Append */ false);
- State.ValidateDownload(Paths, Sizes, "source", "download", DownloadContent);
+ for (uint32_t RangeIndex = 0; RangeIndex < PayloadRanges.size(); RangeIndex++)
+ {
+ const std::pair<uint64_t, uint64_t>& PayloadRange = PayloadRanges[RangeIndex];
+
+ CHECK_EQ(PayloadRange.second, Ranges[RangeIndex].second);
- FolderContent Part1DownloadContent = State.Download(BuildId, BuildPart1Id, {}, "download_part1", /* Append */ false);
- State.ValidateDownload(ManifestFiles1, ManifestSizes1, "source", "download_part1", Part1DownloadContent);
+ IoBuffer ChunkPayload(GetBlobRangesResponse.ResponsePayload, PayloadRange.first, PayloadRange.second);
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed(SharedBuffer(ChunkPayload), RawHash, RawSize);
+ CHECK(CompressedChunk);
+ CHECK_EQ(RawHash, BlockDescription.ChunkRawHashes[ChunkIndexesToFetch[RangeIndex]]);
+ CHECK_EQ(RawSize, BlockDescription.ChunkRawLengths[ChunkIndexesToFetch[RangeIndex]]);
+ }
+ };
- FolderContent Part2DownloadContent = State.Download(BuildId, Oid::Zero, BuildPart2Name, "download_part2", /* Append */ false);
- State.ValidateDownload(ManifestFiles2, ManifestSizes2, "source", "download_part2", Part2DownloadContent);
+ {
+ // Single
+ std::vector<uint32_t> ChunkIndexesToFetch{uint32_t(BlockDescription.ChunkCompressedLengths.size() / 2)};
+ Validate(ChunkIndexesToFetch);
+ }
+ {
+ // Many
+ std::vector<uint32_t> ChunkIndexesToFetch;
+ for (uint32_t Index = 0; Index < BlockDescription.ChunkCompressedLengths.size() / 16; Index++)
+ {
+ ChunkIndexesToFetch.push_back(uint32_t(BlockDescription.ChunkCompressedLengths.size() / 6 + Index * 7));
+ ChunkIndexesToFetch.push_back(uint32_t(BlockDescription.ChunkCompressedLengths.size() / 6 + Index * 7 + 1));
+ ChunkIndexesToFetch.push_back(uint32_t(BlockDescription.ChunkCompressedLengths.size() / 6 + Index * 7 + 3));
+ }
+ Validate(ChunkIndexesToFetch);
+ }
- (void)State.Download(BuildId, BuildPart1Id, BuildPart1Name, "download_part1+2", /* Append */ false);
- FolderContent Part1And2DownloadContent = State.Download(BuildId, BuildPart2Id, {}, "download_part1+2", /* Append */ true);
- State.ValidateDownload(Paths, Sizes, "source", "download_part1+2", Part1And2DownloadContent);
+ {
+ // First and last
+ std::vector<uint32_t> ChunkIndexesToFetch{0, uint32_t(BlockDescription.ChunkCompressedLengths.size() - 1)};
+ Validate(ChunkIndexesToFetch);
+ }
}
+TEST_SUITE_END();
void
buildstorageoperations_forcelink()
diff --git a/src/zenremotestore/builds/buildstorageutil.cpp b/src/zenremotestore/builds/buildstorageutil.cpp
index 36b45e800..2ae726e29 100644
--- a/src/zenremotestore/builds/buildstorageutil.cpp
+++ b/src/zenremotestore/builds/buildstorageutil.cpp
@@ -63,11 +63,15 @@ ResolveBuildStorage(OperationLogOutput& Output,
std::string HostUrl;
std::string HostName;
+ double HostLatencySec = -1.0;
+ uint64_t HostMaxRangeCountPerRequest = 1;
std::string CacheUrl;
std::string CacheName;
- bool HostAssumeHttp2 = ClientSettings.AssumeHttp2;
- bool CacheAssumeHttp2 = ClientSettings.AssumeHttp2;
+ bool HostAssumeHttp2 = ClientSettings.AssumeHttp2;
+ bool CacheAssumeHttp2 = ClientSettings.AssumeHttp2;
+ double CacheLatencySec = -1.0;
+ uint64_t CacheMaxRangeCountPerRequest = 1;
JupiterServerDiscovery DiscoveryResponse;
const std::string_view DiscoveryHost = Host.empty() ? OverrideHost : Host;
@@ -98,8 +102,10 @@ ResolveBuildStorage(OperationLogOutput& Output,
{
ZEN_OPERATION_LOG_INFO(Output, "Server endpoint at '{}/api/v1/status/servers' succeeded", OverrideHost);
}
- HostUrl = OverrideHost;
- HostName = GetHostNameFromUrl(OverrideHost);
+ HostUrl = OverrideHost;
+ HostName = GetHostNameFromUrl(OverrideHost);
+ HostLatencySec = TestResult.LatencySeconds;
+ HostMaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest;
}
else
{
@@ -134,9 +140,11 @@ ResolveBuildStorage(OperationLogOutput& Output,
ZEN_OPERATION_LOG_INFO(Output, "Server endpoint at '{}/api/v1/status/servers' succeeded", ServerEndpoint.BaseUrl);
}
- HostUrl = ServerEndpoint.BaseUrl;
- HostAssumeHttp2 = ServerEndpoint.AssumeHttp2;
- HostName = ServerEndpoint.Name;
+ HostUrl = ServerEndpoint.BaseUrl;
+ HostAssumeHttp2 = ServerEndpoint.AssumeHttp2;
+ HostName = ServerEndpoint.Name;
+ HostLatencySec = TestResult.LatencySeconds;
+ HostMaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest;
break;
}
else
@@ -180,9 +188,11 @@ ResolveBuildStorage(OperationLogOutput& Output,
ZEN_OPERATION_LOG_INFO(Output, "Cache endpoint at '{}/status/builds' succeeded", CacheEndpoint.BaseUrl);
}
- CacheUrl = CacheEndpoint.BaseUrl;
- CacheAssumeHttp2 = CacheEndpoint.AssumeHttp2;
- CacheName = CacheEndpoint.Name;
+ CacheUrl = CacheEndpoint.BaseUrl;
+ CacheAssumeHttp2 = CacheEndpoint.AssumeHttp2;
+ CacheName = CacheEndpoint.Name;
+ CacheLatencySec = TestResult.LatencySeconds;
+ CacheMaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest;
break;
}
}
@@ -204,6 +214,7 @@ ResolveBuildStorage(OperationLogOutput& Output,
CacheUrl = ZenServerLocalHostUrl;
CacheAssumeHttp2 = false;
CacheName = "localhost";
+ CacheLatencySec = TestResult.LatencySeconds;
}
}
});
@@ -219,8 +230,10 @@ ResolveBuildStorage(OperationLogOutput& Output,
if (ZenCacheEndpointTestResult TestResult = TestZenCacheEndpoint(ZenCacheHost, /*AssumeHttp2*/ false, ClientSettings.Verbose);
TestResult.Success)
{
- CacheUrl = ZenCacheHost;
- CacheName = GetHostNameFromUrl(ZenCacheHost);
+ CacheUrl = ZenCacheHost;
+ CacheName = GetHostNameFromUrl(ZenCacheHost);
+ CacheLatencySec = TestResult.LatencySeconds;
+ CacheMaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest;
}
else
{
@@ -228,13 +241,34 @@ ResolveBuildStorage(OperationLogOutput& Output,
}
}
- return BuildStorageResolveResult{.HostUrl = HostUrl,
- .HostName = HostName,
- .HostAssumeHttp2 = HostAssumeHttp2,
+ return BuildStorageResolveResult{
+ .Cloud = {.Address = HostUrl,
+ .Name = HostName,
+ .AssumeHttp2 = HostAssumeHttp2,
+ .LatencySec = HostLatencySec,
+ .Caps = BuildStorageResolveResult::Capabilities{.MaxRangeCountPerRequest = HostMaxRangeCountPerRequest}},
+ .Cache = {.Address = CacheUrl,
+ .Name = CacheName,
+ .AssumeHttp2 = CacheAssumeHttp2,
+ .LatencySec = CacheLatencySec,
+ .Caps = BuildStorageResolveResult::Capabilities{.MaxRangeCountPerRequest = CacheMaxRangeCountPerRequest}}};
+}
- .CacheUrl = CacheUrl,
- .CacheName = CacheName,
- .CacheAssumeHttp2 = CacheAssumeHttp2};
+std::vector<ChunkBlockDescription>
+ParseBlockMetadatas(std::span<const CbObject> BlockMetadatas)
+{
+ std::vector<ChunkBlockDescription> UnorderedList;
+ UnorderedList.reserve(BlockMetadatas.size());
+ for (size_t CacheBlockMetadataIndex = 0; CacheBlockMetadataIndex < BlockMetadatas.size(); CacheBlockMetadataIndex++)
+ {
+ const CbObject& CacheBlockMetadata = BlockMetadatas[CacheBlockMetadataIndex];
+ ChunkBlockDescription Description = ParseChunkBlockDescription(CacheBlockMetadata);
+ if (Description.BlockHash != IoHash::Zero)
+ {
+ UnorderedList.emplace_back(std::move(Description));
+ }
+ }
+ return UnorderedList;
}
std::vector<ChunkBlockDescription>
@@ -242,7 +276,6 @@ GetBlockDescriptions(OperationLogOutput& Output,
BuildStorageBase& Storage,
BuildStorageCache* OptionalCacheStorage,
const Oid& BuildId,
- const Oid& BuildPartId,
std::span<const IoHash> BlockRawHashes,
bool AttemptFallback,
bool IsQuiet,
@@ -250,37 +283,20 @@ GetBlockDescriptions(OperationLogOutput& Output,
{
using namespace std::literals;
- if (!IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(Output, "Fetching metadata for {} blocks", BlockRawHashes.size());
- }
-
- Stopwatch GetBlockMetadataTimer;
-
std::vector<ChunkBlockDescription> UnorderedList;
tsl::robin_map<IoHash, size_t, IoHash::Hasher> BlockDescriptionLookup;
if (OptionalCacheStorage && !BlockRawHashes.empty())
{
std::vector<CbObject> CacheBlockMetadatas = OptionalCacheStorage->GetBlobMetadatas(BuildId, BlockRawHashes);
- UnorderedList.reserve(CacheBlockMetadatas.size());
- for (size_t CacheBlockMetadataIndex = 0; CacheBlockMetadataIndex < CacheBlockMetadatas.size(); CacheBlockMetadataIndex++)
+ if (!CacheBlockMetadatas.empty())
{
- const CbObject& CacheBlockMetadata = CacheBlockMetadatas[CacheBlockMetadataIndex];
- ChunkBlockDescription Description = ParseChunkBlockDescription(CacheBlockMetadata);
- if (Description.BlockHash == IoHash::Zero)
+ UnorderedList = ParseBlockMetadatas(CacheBlockMetadatas);
+ for (size_t DescriptionIndex = 0; DescriptionIndex < UnorderedList.size(); DescriptionIndex++)
{
- ZEN_OPERATION_LOG_WARN(Output, "Unexpected/invalid block metadata received from remote cache, skipping block");
- }
- else
- {
- UnorderedList.emplace_back(std::move(Description));
+ const ChunkBlockDescription& Description = UnorderedList[DescriptionIndex];
+ BlockDescriptionLookup.insert_or_assign(Description.BlockHash, DescriptionIndex);
}
}
- for (size_t DescriptionIndex = 0; DescriptionIndex < UnorderedList.size(); DescriptionIndex++)
- {
- const ChunkBlockDescription& Description = UnorderedList[DescriptionIndex];
- BlockDescriptionLookup.insert_or_assign(Description.BlockHash, DescriptionIndex);
- }
}
if (UnorderedList.size() < BlockRawHashes.size())
@@ -346,15 +362,6 @@ GetBlockDescriptions(OperationLogOutput& Output,
}
}
- if (!IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(Output,
- "GetBlockMetadata for {} took {}. Found {} blocks",
- BuildPartId,
- NiceTimeSpanMs(GetBlockMetadataTimer.GetElapsedTimeMs()),
- Result.size());
- }
-
if (Result.size() != BlockRawHashes.size())
{
std::string ErrorDescription =
diff --git a/src/zenremotestore/builds/filebuildstorage.cpp b/src/zenremotestore/builds/filebuildstorage.cpp
index 55e69de61..2f4904449 100644
--- a/src/zenremotestore/builds/filebuildstorage.cpp
+++ b/src/zenremotestore/builds/filebuildstorage.cpp
@@ -432,6 +432,45 @@ public:
return IoBuffer{};
}
+ virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId,
+ const IoHash& RawHash,
+ std::span<const std::pair<uint64_t, uint64_t>> Ranges) override
+ {
+ ZEN_TRACE_CPU("FileBuildStorage::GetBuildBlobRanges");
+ ZEN_UNUSED(BuildId);
+ ZEN_ASSERT(!Ranges.empty());
+
+ uint64_t ReceivedBytes = 0;
+ uint64_t SentBytes = Ranges.size() * 2 * 8;
+
+ SimulateLatency(SentBytes, 0);
+ auto _ = MakeGuard([&]() { SimulateLatency(0, ReceivedBytes); });
+
+ Stopwatch ExecutionTimer;
+ auto __ = MakeGuard([&]() { AddStatistic(ExecutionTimer, SentBytes, ReceivedBytes); });
+
+ BuildBlobRanges Result;
+
+ const std::filesystem::path BlockPath = GetBlobPayloadPath(RawHash);
+ if (IsFile(BlockPath))
+ {
+ BasicFile File(BlockPath, BasicFile::Mode::kRead);
+
+ uint64_t RangeOffset = Ranges.front().first;
+ uint64_t RangeBytes = Ranges.back().first + Ranges.back().second - RangeOffset;
+ Result.PayloadBuffer = IoBufferBuilder::MakeFromFileHandle(File.Detach(), RangeOffset, RangeBytes);
+
+ Result.Ranges.reserve(Ranges.size());
+
+ for (const std::pair<uint64_t, uint64_t>& Range : Ranges)
+ {
+ Result.Ranges.push_back(std::make_pair(Range.first - RangeOffset, Range.second));
+ }
+ ReceivedBytes = Result.PayloadBuffer.GetSize();
+ }
+ return Result;
+ }
+
virtual std::vector<std::function<void()>> GetLargeBuildBlob(const Oid& BuildId,
const IoHash& RawHash,
uint64_t ChunkSize,
diff --git a/src/zenremotestore/builds/jupiterbuildstorage.cpp b/src/zenremotestore/builds/jupiterbuildstorage.cpp
index 23d0ddd4c..8e16da1a9 100644
--- a/src/zenremotestore/builds/jupiterbuildstorage.cpp
+++ b/src/zenremotestore/builds/jupiterbuildstorage.cpp
@@ -21,7 +21,7 @@ namespace zen {
using namespace std::literals;
namespace {
- void ThrowFromJupiterResult(const JupiterResult& Result, std::string_view Prefix)
+ [[noreturn]] void ThrowFromJupiterResult(const JupiterResult& Result, std::string_view Prefix)
{
int Error = Result.ErrorCode < (int)HttpResponseCode::Continue ? Result.ErrorCode : 0;
HttpResponseCode Status =
@@ -295,6 +295,26 @@ public:
return std::move(GetBuildBlobResult.Response);
}
+ virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId,
+ const IoHash& RawHash,
+ std::span<const std::pair<uint64_t, uint64_t>> Ranges) override
+ {
+ ZEN_TRACE_CPU("Jupiter::GetBuildBlob");
+
+ Stopwatch ExecutionTimer;
+ auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); });
+ CreateDirectories(m_TempFolderPath);
+
+ BuildBlobRangesResult GetBuildBlobResult =
+ m_Session.GetBuildBlob(m_Namespace, m_Bucket, BuildId, RawHash, m_TempFolderPath, Ranges);
+ AddStatistic(GetBuildBlobResult);
+ if (!GetBuildBlobResult.Success)
+ {
+ ThrowFromJupiterResult(GetBuildBlobResult, "Failed fetching build blob ranges"sv);
+ }
+ return BuildBlobRanges{.PayloadBuffer = std::move(GetBuildBlobResult.Response), .Ranges = std::move(GetBuildBlobResult.Ranges)};
+ }
+
virtual std::vector<std::function<void()>> GetLargeBuildBlob(const Oid& BuildId,
const IoHash& RawHash,
uint64_t ChunkSize,
diff --git a/src/zenremotestore/chunking/chunkblock.cpp b/src/zenremotestore/chunking/chunkblock.cpp
index c4d8653f4..cca32c17d 100644
--- a/src/zenremotestore/chunking/chunkblock.cpp
+++ b/src/zenremotestore/chunking/chunkblock.cpp
@@ -7,27 +7,201 @@
#include <zencore/logging.h>
#include <zencore/timer.h>
#include <zencore/trace.h>
-
#include <zenremotestore/operationlogoutput.h>
-#include <vector>
+#include <numeric>
ZEN_THIRD_PARTY_INCLUDES_START
-#include <tsl/robin_map.h>
+#include <tsl/robin_set.h>
ZEN_THIRD_PARTY_INCLUDES_END
#if ZEN_WITH_TESTS
# include <zencore/testing.h>
# include <zencore/testutils.h>
-
-# include <unordered_map>
-# include <numeric>
#endif // ZEN_WITH_TESTS
namespace zen {
using namespace std::literals;
+namespace chunkblock_impl {
+
+ struct RangeDescriptor
+ {
+ uint64_t RangeStart = 0;
+ uint64_t RangeLength = 0;
+ uint32_t ChunkBlockIndexStart = 0;
+ uint32_t ChunkBlockIndexCount = 0;
+ };
+
+ void MergeCheapestRange(std::vector<RangeDescriptor>& InOutRanges)
+ {
+ ZEN_ASSERT(InOutRanges.size() > 1);
+
+ size_t BestRangeIndexToCollapse = SIZE_MAX;
+ uint64_t BestGap = (uint64_t)-1;
+
+ for (size_t RangeIndex = 0; RangeIndex < InOutRanges.size() - 1; RangeIndex++)
+ {
+ const RangeDescriptor& Range = InOutRanges[RangeIndex];
+ const RangeDescriptor& NextRange = InOutRanges[RangeIndex + 1];
+ uint64_t Gap = NextRange.RangeStart - (Range.RangeStart + Range.RangeLength);
+ if (Gap < BestGap)
+ {
+ BestRangeIndexToCollapse = RangeIndex;
+ BestGap = Gap;
+ }
+ else if (Gap == BestGap)
+ {
+ const RangeDescriptor& BestRange = InOutRanges[BestRangeIndexToCollapse];
+ const RangeDescriptor& BestNextRange = InOutRanges[BestRangeIndexToCollapse + 1];
+ uint64_t BestMergedSize = (BestNextRange.RangeStart + BestNextRange.RangeLength) - BestRange.RangeStart;
+ uint64_t MergedSize = (NextRange.RangeStart + NextRange.RangeLength) - Range.RangeStart;
+ if (MergedSize < BestMergedSize)
+ {
+ BestRangeIndexToCollapse = RangeIndex;
+ }
+ }
+ }
+
+ ZEN_ASSERT(BestRangeIndexToCollapse != SIZE_MAX);
+ ZEN_ASSERT(BestRangeIndexToCollapse < InOutRanges.size() - 1);
+ ZEN_ASSERT(BestGap != (uint64_t)-1);
+
+ RangeDescriptor& BestRange = InOutRanges[BestRangeIndexToCollapse];
+ const RangeDescriptor& BestNextRange = InOutRanges[BestRangeIndexToCollapse + 1];
+ BestRange.RangeLength = BestNextRange.RangeStart - BestRange.RangeStart + BestNextRange.RangeLength;
+ BestRange.ChunkBlockIndexCount =
+ BestNextRange.ChunkBlockIndexStart - BestRange.ChunkBlockIndexStart + BestNextRange.ChunkBlockIndexCount;
+ InOutRanges.erase(InOutRanges.begin() + BestRangeIndexToCollapse + 1);
+ }
+
+ std::vector<RangeDescriptor> GetBlockRanges(const ChunkBlockDescription& BlockDescription,
+ const uint64_t ChunkStartOffsetInBlock,
+ std::span<const uint32_t> BlockChunkIndexNeeded)
+ {
+ ZEN_TRACE_CPU("GetBlockRanges");
+ std::vector<RangeDescriptor> BlockRanges;
+ {
+ uint64_t CurrentOffset = ChunkStartOffsetInBlock;
+ uint32_t ChunkBlockIndex = 0;
+ uint32_t NeedBlockChunkIndexOffset = 0;
+ RangeDescriptor NextRange;
+ while (NeedBlockChunkIndexOffset < BlockChunkIndexNeeded.size() && ChunkBlockIndex < BlockDescription.ChunkRawHashes.size())
+ {
+ const uint32_t ChunkCompressedLength = BlockDescription.ChunkCompressedLengths[ChunkBlockIndex];
+ if (ChunkBlockIndex < BlockChunkIndexNeeded[NeedBlockChunkIndexOffset])
+ {
+ if (NextRange.RangeLength > 0)
+ {
+ BlockRanges.push_back(NextRange);
+ NextRange = {};
+ }
+ ChunkBlockIndex++;
+ CurrentOffset += ChunkCompressedLength;
+ }
+ else if (ChunkBlockIndex == BlockChunkIndexNeeded[NeedBlockChunkIndexOffset])
+ {
+ if (NextRange.RangeLength == 0)
+ {
+ NextRange.RangeStart = CurrentOffset;
+ NextRange.ChunkBlockIndexStart = ChunkBlockIndex;
+ }
+ NextRange.RangeLength += ChunkCompressedLength;
+ NextRange.ChunkBlockIndexCount++;
+ ChunkBlockIndex++;
+ CurrentOffset += ChunkCompressedLength;
+ NeedBlockChunkIndexOffset++;
+ }
+ else
+ {
+ ZEN_ASSERT(false);
+ }
+ }
+ if (NextRange.RangeLength > 0)
+ {
+ BlockRanges.push_back(NextRange);
+ }
+ }
+ ZEN_ASSERT(!BlockRanges.empty());
+ return BlockRanges;
+ }
+
+ std::vector<RangeDescriptor> OptimizeRanges(uint64_t TotalBlockSize,
+ std::span<const RangeDescriptor> ExactRanges,
+ double LatencySec,
+ uint64_t SpeedBytesPerSec,
+ uint64_t MaxRangeCountPerRequest,
+ uint64_t MaxRangesPerBlock)
+ {
+ ZEN_TRACE_CPU("OptimizeRanges");
+ ZEN_ASSERT(MaxRangesPerBlock > 0);
+ std::vector<RangeDescriptor> Ranges(ExactRanges.begin(), ExactRanges.end());
+
+ while (Ranges.size() > MaxRangesPerBlock)
+ {
+ MergeCheapestRange(Ranges);
+ }
+
+ while (true)
+ {
+ const std::uint64_t RangeTotalSize =
+ std::accumulate(Ranges.begin(), Ranges.end(), uint64_t(0u), [](uint64_t Current, const RangeDescriptor& Value) {
+ return Current + Value.RangeLength;
+ });
+
+ const size_t RangeCount = Ranges.size();
+ const uint64_t RequestCount =
+ MaxRangeCountPerRequest == (uint64_t)-1 ? 1 : (RangeCount + MaxRangeCountPerRequest - 1) / MaxRangeCountPerRequest;
+ uint64_t RequestTimeAsBytes = uint64_t(SpeedBytesPerSec * RequestCount * LatencySec);
+
+ if (RangeCount == 1)
+ {
+ // Does fetching the full block add less time than the time it takes to complete a single request?
+ if (TotalBlockSize - RangeTotalSize < SpeedBytesPerSec * LatencySec)
+ {
+ const std::uint64_t InitialRangeTotalSize =
+ std::accumulate(ExactRanges.begin(),
+ ExactRanges.end(),
+ uint64_t(0u),
+ [](uint64_t Current, const RangeDescriptor& Value) { return Current + Value.RangeLength; });
+
+ ZEN_DEBUG(
+ "Latency round trip takes as long as receiving the extra redundant bytes - go full block, dropping {} of slack, "
+ "adding {} of bytes to fetch, for block of size {}",
+ NiceBytes(TotalBlockSize - RangeTotalSize),
+ NiceBytes(TotalBlockSize - InitialRangeTotalSize),
+ NiceBytes(TotalBlockSize));
+ return {};
+ }
+ else
+ {
+ return Ranges;
+ }
+ }
+
+ if (RequestTimeAsBytes < (TotalBlockSize - RangeTotalSize))
+ {
+ return Ranges;
+ }
+
+ if (RangeCount == 2)
+ {
+ // Merge to single range
+ Ranges.front().RangeLength = Ranges.back().RangeStart - Ranges.front().RangeStart + Ranges.back().RangeLength;
+ Ranges.front().ChunkBlockIndexCount =
+ Ranges.back().ChunkBlockIndexStart - Ranges.front().ChunkBlockIndexStart + Ranges.back().ChunkBlockIndexCount;
+ Ranges.pop_back();
+ }
+ else
+ {
+ MergeCheapestRange(Ranges);
+ }
+ }
+ }
+
+} // namespace chunkblock_impl
+
ChunkBlockDescription
ParseChunkBlockDescription(const CbObjectView& BlockObject)
{
@@ -455,9 +629,299 @@ FindReuseBlocks(OperationLogOutput& Output,
return FilteredReuseBlockIndexes;
}
+// Construct an analyser over a set of chunk-block descriptions.
+// Stores references/spans only: LogOutput, BlockDescriptions and Options must
+// outlive this instance (no copies are taken).
+ChunkBlockAnalyser::ChunkBlockAnalyser(OperationLogOutput& LogOutput,
+                                       std::span<const ChunkBlockDescription> BlockDescriptions,
+                                       const Options& Options)
+: m_LogOutput(LogOutput)
+, m_BlockDescriptions(BlockDescriptions)
+, m_Options(Options)
+{
+}
+
+// Determine, per block, which block-local chunk positions still need fetching.
+//
+// NeedsBlockChunk is queried once per remote chunk index and the answers cached in a
+// bitmap. Blocks are then processed in ascending order of "slack" (compressed bytes in
+// the block that are NOT needed), so a chunk duplicated across several blocks is picked
+// up from the block where fetching it wastes the least bandwidth. Each remote chunk is
+// assigned to at most one block; blocks that end up contributing nothing are omitted
+// from the result.
+//
+// NOTE(review): assumes remote chunk indices in ChunkHashToChunkIndex are dense in
+// [0, ChunkHashToChunkIndex.size()) — the bitmaps below are sized and indexed on that
+// basis. TODO confirm against the callers that build the map.
+std::vector<ChunkBlockAnalyser::NeededBlock>
+ChunkBlockAnalyser::GetNeeded(const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& ChunkHashToChunkIndex,
+                              std::function<bool(uint32_t ChunkIndex)>&& NeedsBlockChunk)
+{
+    ZEN_TRACE_CPU("ChunkBlockAnalyser::GetNeeded");
+
+    std::vector<NeededBlock> Result;
+
+    // Cache the caller's verdict once per remote chunk index.
+    std::vector<bool> ChunkIsNeeded(ChunkHashToChunkIndex.size());
+    for (uint32_t ChunkIndex = 0; ChunkIndex < ChunkHashToChunkIndex.size(); ChunkIndex++)
+    {
+        ChunkIsNeeded[ChunkIndex] = NeedsBlockChunk(ChunkIndex);
+    }
+
+    // Pass 1: per-block slack = total compressed size minus the compressed size of
+    // chunks that are both known to the map and needed. Chunks absent from the map
+    // count towards slack.
+    std::vector<uint64_t> BlockSlack(m_BlockDescriptions.size(), 0u);
+    for (uint32_t BlockIndex = 0; BlockIndex < m_BlockDescriptions.size(); BlockIndex++)
+    {
+        const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
+
+        uint64_t BlockUsedSize = 0;
+        uint64_t BlockSize     = 0;
+
+        for (uint32_t ChunkBlockIndex = 0; ChunkBlockIndex < BlockDescription.ChunkRawHashes.size(); ChunkBlockIndex++)
+        {
+            const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex];
+            if (auto It = ChunkHashToChunkIndex.find(ChunkHash); It != ChunkHashToChunkIndex.end())
+            {
+                const uint32_t RemoteChunkIndex = It->second;
+                if (ChunkIsNeeded[RemoteChunkIndex])
+                {
+                    BlockUsedSize += BlockDescription.ChunkCompressedLengths[ChunkBlockIndex];
+                }
+            }
+            BlockSize += BlockDescription.ChunkCompressedLengths[ChunkBlockIndex];
+        }
+        BlockSlack[BlockIndex] = BlockSize - BlockUsedSize;
+    }
+
+    // Visit blocks lowest-slack first so duplicated chunks land in the "tightest" block.
+    std::vector<uint32_t> BlockOrder(m_BlockDescriptions.size());
+    std::iota(BlockOrder.begin(), BlockOrder.end(), 0);
+
+    std::sort(BlockOrder.begin(), BlockOrder.end(), [&BlockSlack](uint32_t Lhs, uint32_t Rhs) {
+        return BlockSlack[Lhs] < BlockSlack[Rhs];
+    });
+
+    // Pass 2: greedy assignment — a remote chunk is claimed by the first (lowest-slack)
+    // block that contains it; later blocks skip already-claimed chunks.
+    std::vector<bool> ChunkIsPickedUp(ChunkHashToChunkIndex.size(), false);
+
+    for (uint32_t BlockIndex : BlockOrder)
+    {
+        const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
+
+        std::vector<uint32_t> BlockChunkIndexNeeded;
+
+        for (uint32_t ChunkBlockIndex = 0; ChunkBlockIndex < BlockDescription.ChunkRawHashes.size(); ChunkBlockIndex++)
+        {
+            const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex];
+            if (auto It = ChunkHashToChunkIndex.find(ChunkHash); It != ChunkHashToChunkIndex.end())
+            {
+                const uint32_t RemoteChunkIndex = It->second;
+                if (ChunkIsNeeded[RemoteChunkIndex])
+                {
+                    if (!ChunkIsPickedUp[RemoteChunkIndex])
+                    {
+                        ChunkIsPickedUp[RemoteChunkIndex] = true;
+                        BlockChunkIndexNeeded.push_back(ChunkBlockIndex);
+                    }
+                }
+            }
+            else
+            {
+                ZEN_DEBUG("Chunk {} not found in block {}", ChunkHash, BlockDescription.BlockHash);
+            }
+        }
+
+        if (!BlockChunkIndexNeeded.empty())
+        {
+            Result.push_back(NeededBlock{.BlockIndex = BlockIndex, .ChunkIndexes = std::move(BlockChunkIndexNeeded)});
+        }
+    }
+    return Result;
+}
+
+// Plan the concrete downloads for the blocks selected by GetNeeded.
+//
+// For every needed block this schedules either the full block (Result.FullBlockIndexes)
+// or a set of byte ranges (Result.BlockRanges), depending on the per-block download
+// mode and — for the range modes — a latency/bandwidth cost model (OptimizeRanges).
+// An empty range set returned by OptimizeRanges means "fetch the full block".
+// Summary statistics are logged unless m_Options.IsQuiet is set.
+//
+// Preconditions: BlockPartialDownloadModes is indexed by block index (it must cover
+// every NeededBlock.BlockIndex); MaxRangeCountPerRequest options must be non-zero,
+// with (uint64_t)-1 meaning "unlimited ranges per request".
+ChunkBlockAnalyser::BlockResult
+ChunkBlockAnalyser::CalculatePartialBlockDownloads(std::span<const NeededBlock> NeededBlocks,
+                                                   std::span<const EPartialBlockDownloadMode> BlockPartialDownloadModes)
+{
+    ZEN_TRACE_CPU("ChunkBlockAnalyser::CalculatePartialBlockDownloads");
+
+    Stopwatch PartialAnalisysTimer;
+
+    ChunkBlockAnalyser::BlockResult Result;
+
+    {
+        // Aggregates for the summary log line:
+        //   MinRequestCount  — one request per block (the theoretical minimum)
+        //   RequestCount     — actual requests after splitting ranges per request limit
+        //   Ideal*           — bytes if exact ranges could always be used
+        //   Actual*          — bytes actually scheduled (after mode/cost-model decisions)
+        //   Full*            — bytes if every block were fetched whole
+        uint64_t MinRequestCount        = 0;
+        uint64_t RequestCount           = 0;
+        uint64_t RangeCount             = 0;
+        uint64_t IdealDownloadTotalSize = 0;
+        uint64_t ActualDownloadTotalSize = 0;
+        uint64_t FullDownloadTotalSize  = 0;
+        for (const NeededBlock& NeededBlock : NeededBlocks)
+        {
+            const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[NeededBlock.BlockIndex];
+            std::span<const uint32_t>    BlockChunkIndexNeeded(NeededBlock.ChunkIndexes);
+            // Chunk payload starts after the compressed-buffer header plus the block's own header.
+            const uint32_t ChunkStartOffsetInBlock =
+                gsl::narrow<uint32_t>(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize);
+            uint64_t TotalBlockSize = std::accumulate(BlockDescription.ChunkCompressedLengths.begin(),
+                                                      BlockDescription.ChunkCompressedLengths.end(),
+                                                      uint64_t(ChunkStartOffsetInBlock));
+            uint64_t ExactRangesSize    = 0;
+            uint64_t DownloadRangesSize = 0;
+            uint64_t FullDownloadSize   = 0;
+
+            // Partial download needs a parseable header and one compressed length per chunk.
+            bool CanDoPartialBlockDownload = (BlockDescription.HeaderSize > 0) &&
+                                             (BlockDescription.ChunkCompressedLengths.size() == BlockDescription.ChunkRawHashes.size());
+
+            if (NeededBlock.ChunkIndexes.size() == BlockDescription.ChunkRawHashes.size() || !CanDoPartialBlockDownload)
+            {
+                // Full block
+                ExactRangesSize    = TotalBlockSize;
+                DownloadRangesSize = TotalBlockSize;
+                FullDownloadSize   = TotalBlockSize;
+                MinRequestCount++;
+                RequestCount++;
+                RangeCount++;
+                Result.FullBlockIndexes.push_back(NeededBlock.BlockIndex);
+            }
+            else if (NeededBlock.ChunkIndexes.empty())
+            {
+                // Not needed
+            }
+            else
+            {
+                FullDownloadSize = TotalBlockSize;
+                std::vector<chunkblock_impl::RangeDescriptor> Ranges =
+                    chunkblock_impl::GetBlockRanges(BlockDescription, ChunkStartOffsetInBlock, BlockChunkIndexNeeded);
+                ExactRangesSize = std::accumulate(
+                    Ranges.begin(),
+                    Ranges.end(),
+                    uint64_t(0),
+                    [](uint64_t Current, const chunkblock_impl::RangeDescriptor& Range) { return Current + Range.RangeLength; });
+
+                EPartialBlockDownloadMode PartialBlockDownloadMode = BlockPartialDownloadModes[NeededBlock.BlockIndex];
+                if (PartialBlockDownloadMode == EPartialBlockDownloadMode::Off)
+                {
+                    // Use full block
+                    MinRequestCount++;
+                    RangeCount++;
+                    RequestCount++;
+                    Result.FullBlockIndexes.push_back(NeededBlock.BlockIndex);
+                    DownloadRangesSize = TotalBlockSize;
+                }
+                else
+                {
+                    const bool IsHighSpeed = (PartialBlockDownloadMode == EPartialBlockDownloadMode::MultiRangeHighSpeed);
+                    uint64_t   MaxRangeCountPerRequest =
+                        IsHighSpeed ? m_Options.HostHighSpeedMaxRangeCountPerRequest : m_Options.HostMaxRangeCountPerRequest;
+                    ZEN_ASSERT(MaxRangeCountPerRequest != 0);
+
+                    if (PartialBlockDownloadMode == EPartialBlockDownloadMode::Exact)
+                    {
+                        // Use exact ranges
+                        for (const chunkblock_impl::RangeDescriptor& Range : Ranges)
+                        {
+                            Result.BlockRanges.push_back(BlockRangeDescriptor{.BlockIndex           = NeededBlock.BlockIndex,
+                                                                              .RangeStart           = Range.RangeStart,
+                                                                              .RangeLength          = Range.RangeLength,
+                                                                              .ChunkBlockIndexStart = Range.ChunkBlockIndexStart,
+                                                                              .ChunkBlockIndexCount = Range.ChunkBlockIndexCount});
+                        }
+
+                        MinRequestCount++;
+                        RangeCount += Ranges.size();
+                        // Round up: each request may carry at most MaxRangeCountPerRequest ranges.
+                        RequestCount += MaxRangeCountPerRequest == (uint64_t)-1
+                                            ? 1
+                                            : (Ranges.size() + MaxRangeCountPerRequest - 1) / MaxRangeCountPerRequest;
+                        DownloadRangesSize = ExactRangesSize;
+                    }
+                    else
+                    {
+                        if (PartialBlockDownloadMode == EPartialBlockDownloadMode::SingleRange)
+                        {
+                            // Use single range
+                            if (Ranges.size() > 1)
+                            {
+                                // Collapse to one span covering first..last range (gaps included).
+                                Ranges = {chunkblock_impl::RangeDescriptor{
+                                    .RangeStart  = Ranges.front().RangeStart,
+                                    .RangeLength = Ranges.back().RangeStart + Ranges.back().RangeLength - Ranges.front().RangeStart,
+                                    .ChunkBlockIndexStart = Ranges.front().ChunkBlockIndexStart,
+                                    .ChunkBlockIndexCount = Ranges.back().ChunkBlockIndexStart + Ranges.back().ChunkBlockIndexCount -
+                                                            Ranges.front().ChunkBlockIndexStart}};
+                            }
+
+                            // We still do the optimize pass to see if it is more effective to use a full block
+                        }
+
+                        double   LatencySec       = IsHighSpeed ? m_Options.HostHighSpeedLatencySec : m_Options.HostLatencySec;
+                        uint64_t SpeedBytesPerSec = IsHighSpeed ? m_Options.HostHighSpeedBytesPerSec : m_Options.HostSpeedBytesPerSec;
+                        // Cost model only runs with valid latency/speed; otherwise ranges pass through as-is.
+                        if (LatencySec > 0.0 && SpeedBytesPerSec > 0u)
+                        {
+                            Ranges = chunkblock_impl::OptimizeRanges(TotalBlockSize,
+                                                                     Ranges,
+                                                                     LatencySec,
+                                                                     SpeedBytesPerSec,
+                                                                     MaxRangeCountPerRequest,
+                                                                     m_Options.MaxRangesPerBlock);
+                        }
+
+                        MinRequestCount++;
+                        if (Ranges.empty())
+                        {
+                            // OptimizeRanges decided the full block is cheaper than ranged requests.
+                            Result.FullBlockIndexes.push_back(NeededBlock.BlockIndex);
+                            RequestCount++;
+                            RangeCount++;
+                            DownloadRangesSize = TotalBlockSize;
+                        }
+                        else
+                        {
+                            for (const chunkblock_impl::RangeDescriptor& Range : Ranges)
+                            {
+                                Result.BlockRanges.push_back(BlockRangeDescriptor{.BlockIndex           = NeededBlock.BlockIndex,
+                                                                                  .RangeStart           = Range.RangeStart,
+                                                                                  .RangeLength          = Range.RangeLength,
+                                                                                  .ChunkBlockIndexStart = Range.ChunkBlockIndexStart,
+                                                                                  .ChunkBlockIndexCount = Range.ChunkBlockIndexCount});
+                            }
+                            RangeCount += Ranges.size();
+                            RequestCount += MaxRangeCountPerRequest == (uint64_t)-1
+                                                ? 1
+                                                : (Ranges.size() + MaxRangeCountPerRequest - 1) / MaxRangeCountPerRequest;
+                        }
+
+                        // Recompute from the (possibly merged) ranges; full block when empty.
+                        DownloadRangesSize = Ranges.empty()
+                                                 ? TotalBlockSize
+                                                 : std::accumulate(Ranges.begin(),
+                                                                   Ranges.end(),
+                                                                   uint64_t(0),
+                                                                   [](uint64_t Current, const chunkblock_impl::RangeDescriptor& Range) {
+                                                                       return Current + Range.RangeLength;
+                                                                   });
+                    }
+                }
+            }
+            IdealDownloadTotalSize += ExactRangesSize;
+            ActualDownloadTotalSize += DownloadRangesSize;
+            FullDownloadTotalSize += FullDownloadSize;
+
+            // Per-block savings trace (only when ranges could actually save bytes).
+            if (ExactRangesSize < FullDownloadSize)
+            {
+                ZEN_DEBUG("Block {}: Full: {}, Ideal: {}, Actual: {}, Saves: {}",
+                          NeededBlock.BlockIndex,
+                          NiceBytes(FullDownloadSize),
+                          NiceBytes(ExactRangesSize),
+                          NiceBytes(DownloadRangesSize),
+                          NiceBytes(FullDownloadSize - DownloadRangesSize));
+            }
+        }
+        // Actual/Ideal = bytes saved vs the theoretical maximum saving.
+        uint64_t Actual = FullDownloadTotalSize - ActualDownloadTotalSize;
+        uint64_t Ideal  = FullDownloadTotalSize - IdealDownloadTotalSize;
+        if (Ideal < FullDownloadTotalSize && !m_Options.IsQuiet)
+        {
+            const double AchievedPercent = Ideal == 0 ? 100.0 : (100.0 * Actual) / Ideal;
+            ZEN_OPERATION_LOG_INFO(m_LogOutput,
+                                   "Block Partial Analysis: Blocks: {}, Full: {}, Ideal: {}, Actual: {}. Skipping {} ({:.1f}%) out of "
+                                   "possible {} using {} extra ranges "
+                                   "via {} extra requests. Completed in {}",
+                                   NeededBlocks.size(),
+                                   NiceBytes(FullDownloadTotalSize),
+                                   NiceBytes(IdealDownloadTotalSize),
+                                   NiceBytes(ActualDownloadTotalSize),
+                                   NiceBytes(FullDownloadTotalSize - ActualDownloadTotalSize),
+                                   AchievedPercent,
+                                   NiceBytes(Ideal),
+                                   RangeCount - MinRequestCount,
+                                   RequestCount - MinRequestCount,
+                                   NiceTimeSpanMs(PartialAnalisysTimer.GetElapsedTimeMs()));
+        }
+    }
+
+    return Result;
+}
+
#if ZEN_WITH_TESTS
-namespace testutils {
+namespace chunkblock_testutils {
static std::vector<std::pair<Oid, CompressedBuffer>> CreateAttachments(
const std::span<const size_t>& Sizes,
OodleCompressionLevel CompressionLevel = OodleCompressionLevel::VeryFast,
@@ -474,12 +938,14 @@ namespace testutils {
return Result;
}
-} // namespace testutils
+} // namespace chunkblock_testutils
+
+TEST_SUITE_BEGIN("remotestore.chunkblock");
-TEST_CASE("project.store.block")
+TEST_CASE("chunkblock.block")
{
using namespace std::literals;
- using namespace testutils;
+ using namespace chunkblock_testutils;
std::vector<std::size_t> AttachmentSizes({7633, 6825, 5738, 8031, 7225, 566, 3656, 6006, 24, 3466, 1093, 4269, 2257, 3685, 3489,
7194, 6151, 5482, 6217, 3511, 6738, 5061, 7537, 2759, 1916, 8210, 2235, 4024, 1582, 5251,
@@ -504,10 +970,10 @@ TEST_CASE("project.store.block")
HeaderSize));
}
-TEST_CASE("project.store.reuseblocks")
+TEST_CASE("chunkblock.reuseblocks")
{
using namespace std::literals;
- using namespace testutils;
+ using namespace chunkblock_testutils;
std::vector<std::vector<std::size_t>> BlockAttachmentSizes(
{std::vector<std::size_t>{7633, 6825, 5738, 8031, 7225, 566, 3656, 6006, 24, 3466, 1093, 4269, 2257, 3685, 3489,
@@ -744,6 +1210,894 @@ TEST_CASE("project.store.reuseblocks")
}
}
+// Test-only helpers for the ChunkBlockAnalyser test cases below.
+namespace chunkblock_analyser_testutils {
+
+    // Build a ChunkBlockDescription without any real payload.
+    // Hashes are derived deterministically from (BlockSeed XOR ChunkIndex) so that the same
+    // seed produces the same hashes — useful for deduplication tests.
+    // Raw and compressed lengths are set equal; only the analyser's bookkeeping is exercised.
+    static ChunkBlockDescription MakeBlockDesc(uint64_t                        HeaderSize,
+                                               std::initializer_list<uint32_t> CompressedLengths,
+                                               uint32_t                        BlockSeed = 0)
+    {
+        ChunkBlockDescription Desc;
+        Desc.HeaderSize = HeaderSize;
+        uint32_t ChunkIndex = 0;
+        for (uint32_t Length : CompressedLengths)
+        {
+            uint64_t HashInput = uint64_t(BlockSeed ^ ChunkIndex);
+            Desc.ChunkRawHashes.push_back(IoHash::HashBuffer(MemoryView(&HashInput, sizeof(HashInput))));
+            Desc.ChunkRawLengths.push_back(Length);
+            Desc.ChunkCompressedLengths.push_back(Length);
+            ChunkIndex++;
+        }
+        return Desc;
+    }
+
+    // Build the robin_map<IoHash, uint32_t> needed by GetNeeded from a flat list of blocks.
+    // First occurrence of each hash wins; index is assigned sequentially across all blocks.
+    [[maybe_unused]] static tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> MakeHashMap(const std::vector<ChunkBlockDescription>& Blocks)
+    {
+        tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> Result;
+        uint32_t                                         Index = 0;
+        for (const ChunkBlockDescription& Block : Blocks)
+        {
+            for (const IoHash& Hash : Block.ChunkRawHashes)
+            {
+                if (!Result.contains(Hash))
+                {
+                    Result.emplace(Hash, Index++);
+                }
+            }
+        }
+        return Result;
+    }
+
+} // namespace chunkblock_analyser_testutils
+
+// MergeCheapestRange must merge the adjacent pair separated by the smallest gap.
+TEST_CASE("chunkblock.mergecheapestrange.picks_smallest_gap")
+{
+    using RD = chunkblock_impl::RangeDescriptor;
+    // Gap between ranges 0-1 is 50, gap between 1-2 is 150 → pair 0-1 gets merged
+    std::vector<RD> Ranges = {
+        {.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 150, .RangeLength = 100, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 400, .RangeLength = 100, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 1},
+    };
+    chunkblock_impl::MergeCheapestRange(Ranges);
+
+    REQUIRE_EQ(2u, Ranges.size());
+    // Merged range spans [0, 250): original 100 + 50 gap + 100.
+    CHECK_EQ(0u, Ranges[0].RangeStart);
+    CHECK_EQ(250u, Ranges[0].RangeLength); // 150+100
+    CHECK_EQ(0u, Ranges[0].ChunkBlockIndexStart);
+    CHECK_EQ(2u, Ranges[0].ChunkBlockIndexCount);
+    // Third range is untouched.
+    CHECK_EQ(400u, Ranges[1].RangeStart);
+    CHECK_EQ(100u, Ranges[1].RangeLength);
+    CHECK_EQ(2u, Ranges[1].ChunkBlockIndexStart);
+    CHECK_EQ(1u, Ranges[1].ChunkBlockIndexCount);
+}
+
+// When two candidate pairs have equal gaps, the pair producing the smaller merged
+// range must win the tie-break.
+TEST_CASE("chunkblock.mergecheapestrange.tiebreak_smaller_merged")
+{
+    using RD = chunkblock_impl::RangeDescriptor;
+    // Gap 0-1 == gap 1-2 == 100; merged size 0-1 (250) < merged size 1-2 (350) → pair 0-1 wins
+    std::vector<RD> Ranges = {
+        {.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 200, .RangeLength = 50, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 350, .RangeLength = 200, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 1},
+    };
+    chunkblock_impl::MergeCheapestRange(Ranges);
+
+    REQUIRE_EQ(2u, Ranges.size());
+    // Pair 0-1 merged: start=0, length = (200+50)-0 = 250
+    CHECK_EQ(0u, Ranges[0].RangeStart);
+    CHECK_EQ(250u, Ranges[0].RangeLength);
+    CHECK_EQ(0u, Ranges[0].ChunkBlockIndexStart);
+    CHECK_EQ(2u, Ranges[0].ChunkBlockIndexCount);
+    // Pair 1 unchanged (was index 2)
+    CHECK_EQ(350u, Ranges[1].RangeStart);
+    CHECK_EQ(200u, Ranges[1].RangeLength);
+    CHECK_EQ(2u, Ranges[1].ChunkBlockIndexStart);
+    CHECK_EQ(1u, Ranges[1].ChunkBlockIndexCount);
+}
+
+// With a cheap round trip (low latency, fast link), OptimizeRanges should leave
+// all exact ranges untouched.
+TEST_CASE("chunkblock.optimizeranges.preserves_ranges_low_latency")
+{
+    using RD = chunkblock_impl::RangeDescriptor;
+    // With MaxRangeCountPerRequest unlimited, RequestCount=1
+    // RequestTimeAsBytes = 100000 * 1 * 0.001 = 100 << slack=7000 → all ranges preserved
+    std::vector<RD> ExactRanges = {
+        {.RangeStart = 0, .RangeLength = 1000, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 2000, .RangeLength = 1000, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 4000, .RangeLength = 1000, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 1},
+    };
+    uint64_t TotalBlockSize      = 10000;
+    double   LatencySec          = 0.001;
+    uint64_t SpeedBytesPerSec    = 100000;
+    uint64_t MaxRangeCountPerReq = (uint64_t)-1;
+    uint64_t MaxRangesPerBlock   = 1024;
+
+    auto Result =
+        chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, MaxRangeCountPerReq, MaxRangesPerBlock);
+
+    REQUIRE_EQ(3u, Result.size());
+}
+
+// A single range whose slack is smaller than one latency round trip should be
+// dropped in favour of a full-block fetch (empty result = "fetch full block").
+TEST_CASE("chunkblock.optimizeranges.falls_back_to_full_block")
+{
+    using RD = chunkblock_impl::RangeDescriptor;
+    // 1 range already; slack=100 < SpeedBytesPerSec*LatencySec=200 → full block (empty result)
+    std::vector<RD> ExactRanges = {
+        {.RangeStart = 100, .RangeLength = 900, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 3},
+    };
+    uint64_t TotalBlockSize      = 1000;
+    double   LatencySec          = 0.01;
+    uint64_t SpeedBytesPerSec    = 20000;
+    uint64_t MaxRangeCountPerReq = (uint64_t)-1;
+    uint64_t MaxRangesPerBlock   = 1024;
+
+    auto Result =
+        chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, MaxRangeCountPerReq, MaxRangesPerBlock);
+
+    CHECK(Result.empty());
+}
+
+// MaxRangesPerBlock must be enforced as a hard upper bound on the range count,
+// independent of the latency/bandwidth cost model.
+TEST_CASE("chunkblock.optimizeranges.maxrangesperblock_clamp")
+{
+    using RD = chunkblock_impl::RangeDescriptor;
+    // 5 input ranges; MaxRangesPerBlock=2 clamps to ≤2 before the cost model runs
+    std::vector<RD> ExactRanges = {
+        {.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 300, .RangeLength = 100, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 600, .RangeLength = 100, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 900, .RangeLength = 100, .ChunkBlockIndexStart = 3, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 1200, .RangeLength = 100, .ChunkBlockIndexStart = 4, .ChunkBlockIndexCount = 1},
+    };
+    uint64_t TotalBlockSize      = 5000;
+    double   LatencySec          = 0.001;
+    uint64_t SpeedBytesPerSec    = 100000;
+    uint64_t MaxRangeCountPerReq = (uint64_t)-1;
+    uint64_t MaxRangesPerBlock   = 2;
+
+    auto Result =
+        chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, MaxRangeCountPerReq, MaxRangesPerBlock);
+
+    CHECK(Result.size() <= 2u);
+    CHECK(!Result.empty());
+}
+
+// A low per-request range limit multiplies the request count (and thus latency cost),
+// which must push the optimizer to merge ranges that an unlimited limit preserves.
+TEST_CASE("chunkblock.optimizeranges.low_maxrangecountperrequest_drives_merge")
+{
+    using RD = chunkblock_impl::RangeDescriptor;
+    // MaxRangeCountPerRequest=1 means RequestCount==RangeCount; high latency drives merging
+    // With MaxRangeCountPerRequest=-1 the same 3 ranges would be preserved (verified by comment below)
+    std::vector<RD> ExactRanges = {
+        {.RangeStart = 100, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 250, .RangeLength = 100, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 400, .RangeLength = 100, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 1},
+    };
+    uint64_t TotalBlockSize   = 1000;
+    double   LatencySec       = 1.0;
+    uint64_t SpeedBytesPerSec = 500;
+    // With MaxRangeCountPerRequest=-1: RequestCount=1, RequestTimeAsBytes=500 < slack=700 → preserved
+    // With MaxRangeCountPerRequest=1: RequestCount=3, RequestTimeAsBytes=1500 > slack=700 → merged
+    uint64_t MaxRangesPerBlock = 1024;
+
+    auto Unlimited =
+        chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, (uint64_t)-1, MaxRangesPerBlock);
+    CHECK_EQ(3u, Unlimited.size());
+
+    auto Limited =
+        chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, uint64_t(1), MaxRangesPerBlock);
+    CHECK(Limited.size() < 3u);
+}
+
+// With unlimited ranges per request, the cost model charges a single round trip
+// regardless of range count, so many small ranges survive even at higher latency.
+TEST_CASE("chunkblock.optimizeranges.unlimited_rangecountperrequest_no_extra_cost")
+{
+    using RD = chunkblock_impl::RangeDescriptor;
+    // MaxRangeCountPerRequest=-1 → RequestCount always 1, even with many ranges and high latency
+    std::vector<RD> ExactRanges = {
+        {.RangeStart = 0, .RangeLength = 50, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 200, .RangeLength = 50, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 400, .RangeLength = 50, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 600, .RangeLength = 50, .ChunkBlockIndexStart = 3, .ChunkBlockIndexCount = 1},
+        {.RangeStart = 800, .RangeLength = 50, .ChunkBlockIndexStart = 4, .ChunkBlockIndexCount = 1},
+    };
+    uint64_t TotalBlockSize      = 5000;
+    double   LatencySec          = 0.1;
+    uint64_t SpeedBytesPerSec    = 10000; // RequestTimeAsBytes=1000 << slack=4750
+    uint64_t MaxRangeCountPerReq = (uint64_t)-1;
+    uint64_t MaxRangesPerBlock   = 1024;
+
+    auto Result =
+        chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, MaxRangeCountPerReq, MaxRangesPerBlock);
+
+    CHECK_EQ(5u, Result.size());
+}
+
+// Exercises the special RangeCount==2 branch that merges directly to a single
+// range, then verifies the merged range is itself discarded for a full block.
+TEST_CASE("chunkblock.optimizeranges.two_range_direct_merge_path")
+{
+    using RD = chunkblock_impl::RangeDescriptor;
+    // Exactly 2 ranges; cost model demands merge; exercises the RangeCount==2 direct-merge branch
+    // After direct merge → 1 range with small slack → full block (empty)
+    std::vector<RD> ExactRanges = {
+        {.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 2},
+        {.RangeStart = 400, .RangeLength = 100, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 2},
+    };
+    uint64_t TotalBlockSize      = 600;
+    double   LatencySec          = 0.1;
+    uint64_t SpeedBytesPerSec    = 5000; // RequestTimeAsBytes=500 > slack=400 on first iter
+    uint64_t MaxRangeCountPerReq = (uint64_t)-1;
+    uint64_t MaxRangesPerBlock   = 1024;
+
+    // Iteration 1: RangeCount=2, RequestCount=1, RequestTimeAsBytes=500 > slack=400 → direct merge
+    // After merge: 1 range [{0,500,0,4}], slack=100 < Speed*Lat=500 → full block
+    auto Result =
+        chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, MaxRangeCountPerReq, MaxRangesPerBlock);
+
+    CHECK(Result.empty());
+}
+
+// When every chunk is needed, GetNeeded returns the single block with all four
+// block-local chunk indexes, in order.
+TEST_CASE("chunkblock.getneeded.all_chunks")
+{
+    using namespace chunkblock_analyser_testutils;
+
+    LoggerRef                           LogRef = Log();
+    std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+    auto                        Block = MakeBlockDesc(50, {100, 100, 100, 100});
+    ChunkBlockAnalyser::Options Options;
+    ChunkBlockAnalyser          Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+    auto HashMap      = MakeHashMap({Block});
+    auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t) { return true; });
+
+    REQUIRE_EQ(1u, NeededBlocks.size());
+    CHECK_EQ(0u, NeededBlocks[0].BlockIndex);
+    REQUIRE_EQ(4u, NeededBlocks[0].ChunkIndexes.size());
+    CHECK_EQ(0u, NeededBlocks[0].ChunkIndexes[0]);
+    CHECK_EQ(1u, NeededBlocks[0].ChunkIndexes[1]);
+    CHECK_EQ(2u, NeededBlocks[0].ChunkIndexes[2]);
+    CHECK_EQ(3u, NeededBlocks[0].ChunkIndexes[3]);
+}
+
+// When no chunk is needed, GetNeeded must return an empty result (blocks that
+// contribute nothing are omitted).
+TEST_CASE("chunkblock.getneeded.no_chunks")
+{
+    using namespace chunkblock_analyser_testutils;
+
+    LoggerRef                           LogRef = Log();
+    std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+    auto                        Block = MakeBlockDesc(50, {100, 100, 100, 100});
+    ChunkBlockAnalyser::Options Options;
+    ChunkBlockAnalyser          Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+    auto HashMap      = MakeHashMap({Block});
+    auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t) { return false; });
+
+    CHECK(NeededBlocks.empty());
+}
+
+// A partial need within a single block yields exactly the needed block-local
+// chunk indexes and nothing more.
+TEST_CASE("chunkblock.getneeded.subset_within_block")
+{
+    using namespace chunkblock_analyser_testutils;
+
+    LoggerRef                           LogRef = Log();
+    std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+    auto                        Block = MakeBlockDesc(50, {100, 100, 100, 100});
+    ChunkBlockAnalyser::Options Options;
+    ChunkBlockAnalyser          Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+    auto HashMap = MakeHashMap({Block});
+    // Indices 0 and 2 are needed; 1 and 3 are not
+    auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t ChunkIndex) { return ChunkIndex == 0 || ChunkIndex == 2; });
+
+    REQUIRE_EQ(1u, NeededBlocks.size());
+    CHECK_EQ(0u, NeededBlocks[0].BlockIndex);
+    REQUIRE_EQ(2u, NeededBlocks[0].ChunkIndexes.size());
+    CHECK_EQ(0u, NeededBlocks[0].ChunkIndexes[0]);
+    CHECK_EQ(2u, NeededBlocks[0].ChunkIndexes[1]);
+}
+
+// A chunk duplicated across two blocks must be assigned to the block with less
+// slack (fewer unneeded bytes), leaving the high-slack block empty.
+TEST_CASE("chunkblock.getneeded.dedup_low_slack_wins")
+{
+    using namespace chunkblock_analyser_testutils;
+
+    LoggerRef                           LogRef = Log();
+    std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+    // Block 0: {H0, H1, SharedH, H3} — 3 of 4 needed (H3 not needed); slack = 100
+    // Block 1: {H4, H5, SharedH, H6} — only SharedH needed; slack = 300
+    // Block 0 has less slack → processed first → SharedH assigned to block 0
+    IoHash SharedH = IoHash::HashBuffer(MemoryView("shared_chunk_dedup", 18));
+    IoHash H0      = IoHash::HashBuffer(MemoryView("block0_chunk0", 13));
+    IoHash H1      = IoHash::HashBuffer(MemoryView("block0_chunk1", 13));
+    IoHash H3      = IoHash::HashBuffer(MemoryView("block0_chunk3", 13));
+    IoHash H4      = IoHash::HashBuffer(MemoryView("block1_chunk0", 13));
+    IoHash H5      = IoHash::HashBuffer(MemoryView("block1_chunk1", 13));
+    IoHash H6      = IoHash::HashBuffer(MemoryView("block1_chunk3", 13));
+
+    ChunkBlockDescription Block0;
+    Block0.HeaderSize             = 50;
+    Block0.ChunkRawHashes         = {H0, H1, SharedH, H3};
+    Block0.ChunkRawLengths        = {100, 100, 100, 100};
+    Block0.ChunkCompressedLengths = {100, 100, 100, 100};
+
+    ChunkBlockDescription Block1;
+    Block1.HeaderSize             = 50;
+    Block1.ChunkRawHashes         = {H4, H5, SharedH, H6};
+    Block1.ChunkRawLengths        = {100, 100, 100, 100};
+    Block1.ChunkCompressedLengths = {100, 100, 100, 100};
+
+    std::vector<ChunkBlockDescription> Blocks = {Block0, Block1};
+    ChunkBlockAnalyser::Options        Options;
+    ChunkBlockAnalyser                 Analyser(*LogOutput, Blocks, Options);
+
+    // Map: H0→0, H1→1, SharedH→2, H3→3, H4→4, H5→5, H6→6
+    auto HashMap = MakeHashMap(Blocks);
+    // Need H0(0), H1(1), SharedH(2) from block 0; SharedH from block 1 (already index 2)
+    // H3(3) not needed; H4,H5,H6 not needed
+    auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t ChunkIndex) { return ChunkIndex <= 2; });
+
+    // Block 0 slack=100 (H3 unused), block 1 slack=300 (H4,H5,H6 unused)
+    // Block 0 processed first; picks up H0, H1, SharedH
+    // Block 1 tries SharedH but it's already picked up → empty → not added
+    REQUIRE_EQ(1u, NeededBlocks.size());
+    CHECK_EQ(0u, NeededBlocks[0].BlockIndex);
+    REQUIRE_EQ(3u, NeededBlocks[0].ChunkIndexes.size());
+    CHECK_EQ(0u, NeededBlocks[0].ChunkIndexes[0]);
+    CHECK_EQ(1u, NeededBlocks[0].ChunkIndexes[1]);
+    CHECK_EQ(2u, NeededBlocks[0].ChunkIndexes[2]);
+}
+
+// A shared chunk must appear exactly once across the whole result, no matter how
+// many blocks contain it.
+TEST_CASE("chunkblock.getneeded.dedup_no_double_pickup")
+{
+    using namespace chunkblock_analyser_testutils;
+
+    LoggerRef                           LogRef = Log();
+    std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+    // SharedH appears in both blocks; should appear in the result exactly once
+    IoHash SharedH = IoHash::HashBuffer(MemoryView("shared_chunk_nodup", 18));
+    IoHash H0      = IoHash::HashBuffer(MemoryView("unique_chunk_b0", 15));
+    IoHash H1      = IoHash::HashBuffer(MemoryView("unique_chunk_b1a", 16));
+    IoHash H2      = IoHash::HashBuffer(MemoryView("unique_chunk_b1b", 16));
+    IoHash H3      = IoHash::HashBuffer(MemoryView("unique_chunk_b1c", 16));
+
+    ChunkBlockDescription Block0;
+    Block0.HeaderSize             = 50;
+    Block0.ChunkRawHashes         = {SharedH, H0};
+    Block0.ChunkRawLengths        = {100, 100};
+    Block0.ChunkCompressedLengths = {100, 100};
+
+    ChunkBlockDescription Block1;
+    Block1.HeaderSize             = 50;
+    Block1.ChunkRawHashes         = {H1, H2, H3, SharedH};
+    Block1.ChunkRawLengths        = {100, 100, 100, 100};
+    Block1.ChunkCompressedLengths = {100, 100, 100, 100};
+
+    std::vector<ChunkBlockDescription> Blocks = {Block0, Block1};
+    ChunkBlockAnalyser::Options        Options;
+    ChunkBlockAnalyser                 Analyser(*LogOutput, Blocks, Options);
+
+    // Map: SharedH→0, H0→1, H1→2, H2→3, H3→4
+    // Only SharedH (index 0) needed; no other chunks
+    auto HashMap      = MakeHashMap(Blocks);
+    auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t ChunkIndex) { return ChunkIndex == 0; });
+
+    // Block 0: SharedH needed, H0 not needed → slack=100
+    // Block 1: SharedH needed, H1/H2/H3 not needed → slack=300
+    // Block 0 processed first → picks up SharedH; Block 1 skips it
+
+    // Count total occurrences of SharedH across all NeededBlocks
+    uint32_t SharedOccurrences = 0;
+    for (const auto& NB : NeededBlocks)
+    {
+        for (uint32_t Idx : NB.ChunkIndexes)
+        {
+            // SharedH is at block-local index 0 in Block0 and index 3 in Block1
+            (void)Idx;
+            SharedOccurrences++;
+        }
+    }
+    CHECK_EQ(1u, SharedOccurrences);
+    REQUIRE_EQ(1u, NeededBlocks.size());
+    CHECK_EQ(0u, NeededBlocks[0].BlockIndex);
+}
+
+// Chunks absent from the hash map (never requested by the remote side) must be
+// silently skipped rather than reported as needed.
+TEST_CASE("chunkblock.getneeded.skips_unrequested_chunks")
+{
+    using namespace chunkblock_analyser_testutils;
+
+    LoggerRef                           LogRef = Log();
+    std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+    // Block has 4 chunks but only 2 appear in the hash map → ChunkIndexes has exactly those 2
+    auto                        Block = MakeBlockDesc(50, {100, 100, 100, 100});
+    ChunkBlockAnalyser::Options Options;
+    ChunkBlockAnalyser          Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+    // Only put chunks at positions 0 and 2 in the map
+    tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> HashMap;
+    HashMap.emplace(Block.ChunkRawHashes[0], 0u);
+    HashMap.emplace(Block.ChunkRawHashes[2], 1u);
+
+    auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t) { return true; });
+
+    REQUIRE_EQ(1u, NeededBlocks.size());
+    CHECK_EQ(0u, NeededBlocks[0].BlockIndex);
+    REQUIRE_EQ(2u, NeededBlocks[0].ChunkIndexes.size());
+    CHECK_EQ(0u, NeededBlocks[0].ChunkIndexes[0]);
+    CHECK_EQ(2u, NeededBlocks[0].ChunkIndexes[1]);
+}
+
+// Two blocks with disjoint chunks both contribute; result ordering follows
+// ascending slack (block 0 with zero slack first).
+TEST_CASE("chunkblock.getneeded.two_blocks_both_contribute")
+{
+    using namespace chunkblock_analyser_testutils;
+
+    LoggerRef                           LogRef = Log();
+    std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+    // Block 0: all 4 needed (slack=0); block 1: 3 of 4 needed (slack=100)
+    // Both blocks contribute chunks → 2 NeededBlocks in result
+    auto Block0 = MakeBlockDesc(50, {100, 100, 100, 100}, /*BlockSeed=*/0);
+    auto Block1 = MakeBlockDesc(50, {100, 100, 100, 100}, /*BlockSeed=*/200);
+
+    std::vector<ChunkBlockDescription> Blocks = {Block0, Block1};
+    ChunkBlockAnalyser::Options        Options;
+    ChunkBlockAnalyser                 Analyser(*LogOutput, Blocks, Options);
+
+    // HashMap: Block0 hashes → indices 0-3, Block1 hashes → indices 4-7
+    auto HashMap = MakeHashMap(Blocks);
+    // Need all Block0 chunks (0-3) and Block1 chunks 0-2 (indices 4-6); not chunk index 7 (Block1 chunk 3)
+    auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t ChunkIndex) { return ChunkIndex <= 6; });
+
+    CHECK_EQ(2u, NeededBlocks.size());
+    // Block 0 has slack=0 (all 4 needed), Block 1 has slack=100 (1 not needed)
+    // Block 0 comes first in result
+    CHECK_EQ(0u, NeededBlocks[0].BlockIndex);
+    CHECK_EQ(4u, NeededBlocks[0].ChunkIndexes.size());
+    CHECK_EQ(1u, NeededBlocks[1].BlockIndex);
+    CHECK_EQ(3u, NeededBlocks[1].ChunkIndexes.size());
+}
+
+// Mode::Off must force a full-block download even when a partial download is possible.
+TEST_CASE("chunkblock.calc.off_mode")
+{
+    using namespace chunkblock_analyser_testutils;
+    using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
+
+    LoggerRef                           LogRef = Log();
+    std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+    // HeaderSize > 0, chunks size matches → CanDoPartialBlockDownload = true
+    // But mode Off forces full block regardless
+    auto                        Block = MakeBlockDesc(50, {100, 200, 300, 400});
+    ChunkBlockAnalyser::Options Options;
+    Options.IsQuiet = true;
+    ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+    std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}};
+    std::vector<Mode>                            Modes        = {Mode::Off};
+
+    auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
+
+    REQUIRE_EQ(1u, Result.FullBlockIndexes.size());
+    CHECK_EQ(0u, Result.FullBlockIndexes[0]);
+    CHECK(Result.BlockRanges.empty());
+}
+
+// Mode::Exact must pass the per-chunk ranges straight through, without merging or
+// falling back to a full block.
+TEST_CASE("chunkblock.calc.exact_mode")
+{
+    using namespace chunkblock_analyser_testutils;
+    using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
+
+    LoggerRef                           LogRef = Log();
+    std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+    auto                        Block = MakeBlockDesc(50, {100, 200, 300, 400});
+    ChunkBlockAnalyser::Options Options;
+    Options.IsQuiet = true;
+    ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+    // Ranges are offset past the compressed-buffer header plus the block header.
+    uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
+
+    // Need chunks 0 and 2 → 2 non-contiguous ranges; Exact mode passes them straight through
+    std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}};
+    std::vector<Mode>                            Modes        = {Mode::Exact};
+
+    auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
+
+    CHECK(Result.FullBlockIndexes.empty());
+    REQUIRE_EQ(2u, Result.BlockRanges.size());
+
+    CHECK_EQ(0u, Result.BlockRanges[0].BlockIndex);
+    CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart);
+    CHECK_EQ(100u, Result.BlockRanges[0].RangeLength);
+    CHECK_EQ(0u, Result.BlockRanges[0].ChunkBlockIndexStart);
+    CHECK_EQ(1u, Result.BlockRanges[0].ChunkBlockIndexCount);
+
+    CHECK_EQ(0u, Result.BlockRanges[1].BlockIndex);
+    CHECK_EQ(ChunkStartOffset + 300u, Result.BlockRanges[1].RangeStart); // 100+200 before chunk 2
+    CHECK_EQ(300u, Result.BlockRanges[1].RangeLength);
+    CHECK_EQ(2u, Result.BlockRanges[1].ChunkBlockIndexStart);
+    CHECK_EQ(1u, Result.BlockRanges[1].ChunkBlockIndexCount);
+}
+
+// SingleRange mode: multiple needed ranges are collapsed into one span that
+// covers everything from the first needed chunk to the last — including the
+// unneeded chunk 1 that sits between them.
+TEST_CASE("chunkblock.calc.singlerange_mode")
+{
+ using namespace chunkblock_analyser_testutils;
+ using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
+
+ LoggerRef LogRef = Log();
+ std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+ auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
+ // Default HostLatencySec=-1 → OptimizeRanges not called after SingleRange collapse
+ ChunkBlockAnalyser::Options Options;
+ Options.IsQuiet = true;
+ ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+ uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
+
+ // Need chunks 0 and 2 → 2 ranges that get collapsed to 1
+ std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}};
+ std::vector<Mode> Modes = {Mode::SingleRange};
+
+ auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
+
+ CHECK(Result.FullBlockIndexes.empty());
+ REQUIRE_EQ(1u, Result.BlockRanges.size());
+ CHECK_EQ(0u, Result.BlockRanges[0].BlockIndex);
+ CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart);
+ // Spans from chunk 0 start to chunk 2 end: 100+200+300=600
+ CHECK_EQ(600u, Result.BlockRanges[0].RangeLength);
+ CHECK_EQ(0u, Result.BlockRanges[0].ChunkBlockIndexStart);
+ // ChunkBlockIndexCount = (2+1) - 0 = 3
+ CHECK_EQ(3u, Result.BlockRanges[0].ChunkBlockIndexCount);
+}
+
+// MultiRange mode with a cheap connection: the per-request cost is far below
+// the bytes saved, so the optimiser keeps the two separate ranges.
+TEST_CASE("chunkblock.calc.multirange_mode")
+{
+ using namespace chunkblock_analyser_testutils;
+ using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
+
+ LoggerRef LogRef = Log();
+ std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+ auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
+ // Low latency: RequestTimeAsBytes=100 << slack → OptimizeRanges preserves ranges
+ ChunkBlockAnalyser::Options Options;
+ Options.IsQuiet = true;
+ Options.HostLatencySec = 0.001;
+ Options.HostSpeedBytesPerSec = 100000;
+ ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+ uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
+
+ std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}};
+ std::vector<Mode> Modes = {Mode::MultiRange};
+
+ auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
+
+ // Both ranges survive the optimisation pass unmerged.
+ CHECK(Result.FullBlockIndexes.empty());
+ REQUIRE_EQ(2u, Result.BlockRanges.size());
+ CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart);
+ CHECK_EQ(100u, Result.BlockRanges[0].RangeLength);
+ CHECK_EQ(ChunkStartOffset + 300u, Result.BlockRanges[1].RangeStart);
+ CHECK_EQ(300u, Result.BlockRanges[1].RangeLength);
+}
+
+// MultiRangeHighSpeed mode: same shape as the MultiRange test but driven by
+// the high-speed latency/bandwidth option pair instead of the standard one.
+TEST_CASE("chunkblock.calc.multirangehighspeed_mode")
+{
+ using namespace chunkblock_analyser_testutils;
+ using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
+
+ LoggerRef LogRef = Log();
+ std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+ auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
+ // Block slack ≈ 714 bytes (TotalBlockSize≈1114, RangeTotalSize=400 for chunks 0+2)
+ // RequestTimeAsBytes = 400000 * 1 * 0.001 = 400 < 714 → ranges preserved
+ ChunkBlockAnalyser::Options Options;
+ Options.IsQuiet = true;
+ Options.HostHighSpeedLatencySec = 0.001;
+ Options.HostHighSpeedBytesPerSec = 400000;
+ ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+ uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
+
+ std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}};
+ std::vector<Mode> Modes = {Mode::MultiRangeHighSpeed};
+
+ auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
+
+ // Both ranges are kept since the modelled request cost is below the slack.
+ CHECK(Result.FullBlockIndexes.empty());
+ REQUIRE_EQ(2u, Result.BlockRanges.size());
+ CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart);
+ CHECK_EQ(100u, Result.BlockRanges[0].RangeLength);
+ CHECK_EQ(ChunkStartOffset + 300u, Result.BlockRanges[1].RangeStart);
+ CHECK_EQ(300u, Result.BlockRanges[1].RangeLength);
+}
+
+// When every chunk in the block is needed there is nothing to save by
+// downloading partially, so the analyser short-circuits to a full block
+// even in Exact mode.
+TEST_CASE("chunkblock.calc.all_chunks_needed_full_block")
+{
+ using namespace chunkblock_analyser_testutils;
+ using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
+
+ LoggerRef LogRef = Log();
+ std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+ auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
+ ChunkBlockAnalyser::Options Options;
+ Options.IsQuiet = true;
+ Options.HostLatencySec = 0.001;
+ Options.HostSpeedBytesPerSec = 100000;
+ ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+ // All 4 chunks needed → short-circuit to full block regardless of mode
+ std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 1, 2, 3}}};
+ std::vector<Mode> Modes = {Mode::Exact};
+
+ auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
+
+ REQUIRE_EQ(1u, Result.FullBlockIndexes.size());
+ CHECK_EQ(0u, Result.FullBlockIndexes[0]);
+ CHECK(Result.BlockRanges.empty());
+}
+
+// A block without header information cannot be addressed by chunk offsets,
+// so the analyser must fall back to a full-block download in every mode.
+TEST_CASE("chunkblock.calc.headersize_zero_forces_full_block")
+{
+ using namespace chunkblock_analyser_testutils;
+ using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
+
+ LoggerRef LogRef = Log();
+ std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+ // HeaderSize=0 → CanDoPartialBlockDownload=false → full block even in Exact mode
+ auto Block = MakeBlockDesc(0, {100, 200, 300, 400});
+ ChunkBlockAnalyser::Options Options;
+ Options.IsQuiet = true;
+ ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+ std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}};
+ std::vector<Mode> Modes = {Mode::Exact};
+
+ auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
+
+ REQUIRE_EQ(1u, Result.FullBlockIndexes.size());
+ CHECK_EQ(0u, Result.FullBlockIndexes[0]);
+ CHECK(Result.BlockRanges.empty());
+}
+
+// Expensive connection plus a range-count cap of 1 per request: the cost
+// model should merge the candidate ranges until a full-block download wins.
+TEST_CASE("chunkblock.calc.low_maxrangecountperrequest")
+{
+ using namespace chunkblock_analyser_testutils;
+ using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
+
+ LoggerRef LogRef = Log();
+ std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+ // 5 chunks of 100 bytes each; need chunks 0, 2, 4 → 3 non-contiguous ranges
+ // With MaxRangeCountPerRequest=1 and high latency, cost model merges aggressively → full block
+ auto Block = MakeBlockDesc(10, {100, 100, 100, 100, 100});
+ ChunkBlockAnalyser::Options Options;
+ Options.IsQuiet = true;
+ Options.HostLatencySec = 0.1;
+ Options.HostSpeedBytesPerSec = 1000;
+ Options.HostMaxRangeCountPerRequest = 1;
+ ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+ std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2, 4}}};
+ std::vector<Mode> Modes = {Mode::MultiRange};
+
+ auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
+
+ // Cost model drives merging: 3 requests × 1000 × 0.1 = 300 > slack ≈ 210+headersize
+ // After merges converges to full block
+ REQUIRE_EQ(1u, Result.FullBlockIndexes.size());
+ CHECK_EQ(0u, Result.FullBlockIndexes[0]);
+ CHECK(Result.BlockRanges.empty());
+}
+
+// With no latency configured (sentinel -1) the optimisation pass is skipped
+// entirely and the raw per-run ranges are returned unchanged.
+TEST_CASE("chunkblock.calc.no_latency_skips_optimize")
+{
+ using namespace chunkblock_analyser_testutils;
+ using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
+
+ LoggerRef LogRef = Log();
+ std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+ auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
+ // Default HostLatencySec=-1 → OptimizeRanges not called; raw GetBlockRanges result used
+ ChunkBlockAnalyser::Options Options;
+ Options.IsQuiet = true;
+ ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+
+ uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
+
+ std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}};
+ std::vector<Mode> Modes = {Mode::MultiRange};
+
+ auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
+
+ // No optimize pass → exact ranges from GetBlockRanges
+ CHECK(Result.FullBlockIndexes.empty());
+ REQUIRE_EQ(2u, Result.BlockRanges.size());
+ CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart);
+ CHECK_EQ(100u, Result.BlockRanges[0].RangeLength);
+ CHECK_EQ(ChunkStartOffset + 300u, Result.BlockRanges[1].RangeStart);
+ CHECK_EQ(300u, Result.BlockRanges[1].RangeLength);
+}
+
+// Each needed block carries its own download mode; this verifies per-block
+// mode handling and that range output preserves block order (block 1 ranges
+// before block 2 ranges).
+TEST_CASE("chunkblock.calc.multiple_blocks_different_modes")
+{
+ using namespace chunkblock_analyser_testutils;
+ using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
+
+ LoggerRef LogRef = Log();
+ std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+
+ // 3 blocks with different modes: Off, Exact, MultiRange
+ auto Block0 = MakeBlockDesc(50, {100, 200, 300, 400}, /*BlockSeed=*/0);
+ auto Block1 = MakeBlockDesc(50, {100, 200, 300, 400}, /*BlockSeed=*/10);
+ auto Block2 = MakeBlockDesc(50, {100, 200, 300, 400}, /*BlockSeed=*/20);
+
+ ChunkBlockAnalyser::Options Options;
+ Options.IsQuiet = true;
+ Options.HostLatencySec = 0.001;
+ Options.HostSpeedBytesPerSec = 100000;
+
+ std::vector<ChunkBlockDescription> Blocks = {Block0, Block1, Block2};
+ ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options);
+
+ // All three blocks share HeaderSize=50, so the chunk payload offset is common.
+ uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + 50;
+
+ std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {
+ {.BlockIndex = 0, .ChunkIndexes = {0, 2}},
+ {.BlockIndex = 1, .ChunkIndexes = {0, 2}},
+ {.BlockIndex = 2, .ChunkIndexes = {0, 2}},
+ };
+ std::vector<Mode> Modes = {Mode::Off, Mode::Exact, Mode::MultiRange};
+
+ auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
+
+ // Block 0: Off → FullBlockIndexes
+ REQUIRE_EQ(1u, Result.FullBlockIndexes.size());
+ CHECK_EQ(0u, Result.FullBlockIndexes[0]);
+
+ // Block 1: Exact → 2 ranges; Block 2: MultiRange (low latency) → 2 ranges
+ // Total: 4 ranges
+ REQUIRE_EQ(4u, Result.BlockRanges.size());
+
+ // First 2 ranges belong to Block 1 (Exact)
+ CHECK_EQ(1u, Result.BlockRanges[0].BlockIndex);
+ CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart);
+ CHECK_EQ(100u, Result.BlockRanges[0].RangeLength);
+ CHECK_EQ(1u, Result.BlockRanges[1].BlockIndex);
+ CHECK_EQ(ChunkStartOffset + 300u, Result.BlockRanges[1].RangeStart);
+ CHECK_EQ(300u, Result.BlockRanges[1].RangeLength);
+
+ // Last 2 ranges belong to Block 2 (MultiRange preserved)
+ CHECK_EQ(2u, Result.BlockRanges[2].BlockIndex);
+ CHECK_EQ(ChunkStartOffset, Result.BlockRanges[2].RangeStart);
+ CHECK_EQ(100u, Result.BlockRanges[2].RangeLength);
+ CHECK_EQ(2u, Result.BlockRanges[3].BlockIndex);
+ CHECK_EQ(ChunkStartOffset + 300u, Result.BlockRanges[3].RangeStart);
+ CHECK_EQ(300u, Result.BlockRanges[3].RangeLength);
+}
+
+// GetBlockRanges edge case: the very first chunk maps to a single range at
+// the chunk payload start offset.
+TEST_CASE("chunkblock.getblockranges.first_chunk_only")
+{
+ using namespace chunkblock_analyser_testutils;
+
+ auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
+ uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
+
+ std::vector<uint32_t> Needed = {0};
+ auto Ranges = chunkblock_impl::GetBlockRanges(Block, ChunkStartOffset, Needed);
+
+ REQUIRE_EQ(1u, Ranges.size());
+ CHECK_EQ(ChunkStartOffset, Ranges[0].RangeStart);
+ CHECK_EQ(100u, Ranges[0].RangeLength);
+ CHECK_EQ(0u, Ranges[0].ChunkBlockIndexStart);
+ CHECK_EQ(1u, Ranges[0].ChunkBlockIndexCount);
+}
+
+// GetBlockRanges edge case: the last chunk yields a range offset by the sum
+// of all preceding chunk lengths.
+TEST_CASE("chunkblock.getblockranges.last_chunk_only")
+{
+ using namespace chunkblock_analyser_testutils;
+
+ auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
+ uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
+
+ std::vector<uint32_t> Needed = {3};
+ auto Ranges = chunkblock_impl::GetBlockRanges(Block, ChunkStartOffset, Needed);
+
+ REQUIRE_EQ(1u, Ranges.size());
+ CHECK_EQ(ChunkStartOffset + 600u, Ranges[0].RangeStart); // 100+200+300 before chunk 3
+ CHECK_EQ(400u, Ranges[0].RangeLength);
+ CHECK_EQ(3u, Ranges[0].ChunkBlockIndexStart);
+ CHECK_EQ(1u, Ranges[0].ChunkBlockIndexCount);
+}
+
+// GetBlockRanges edge case: a single interior chunk is returned as one range
+// positioned after the chunks that precede it.
+TEST_CASE("chunkblock.getblockranges.middle_chunk_only")
+{
+ using namespace chunkblock_analyser_testutils;
+
+ auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
+ uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
+
+ std::vector<uint32_t> Needed = {1};
+ auto Ranges = chunkblock_impl::GetBlockRanges(Block, ChunkStartOffset, Needed);
+
+ REQUIRE_EQ(1u, Ranges.size());
+ CHECK_EQ(ChunkStartOffset + 100u, Ranges[0].RangeStart); // 100 before chunk 1
+ CHECK_EQ(200u, Ranges[0].RangeLength);
+ CHECK_EQ(1u, Ranges[0].ChunkBlockIndexStart);
+ CHECK_EQ(1u, Ranges[0].ChunkBlockIndexCount);
+}
+
+// GetBlockRanges: asking for every chunk produces a single range covering the
+// whole chunk payload.
+TEST_CASE("chunkblock.getblockranges.all_chunks")
+{
+ using namespace chunkblock_analyser_testutils;
+
+ auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
+ uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
+
+ std::vector<uint32_t> Needed = {0, 1, 2, 3};
+ auto Ranges = chunkblock_impl::GetBlockRanges(Block, ChunkStartOffset, Needed);
+
+ REQUIRE_EQ(1u, Ranges.size());
+ CHECK_EQ(ChunkStartOffset, Ranges[0].RangeStart);
+ CHECK_EQ(1000u, Ranges[0].RangeLength); // 100+200+300+400
+ CHECK_EQ(0u, Ranges[0].ChunkBlockIndexStart);
+ CHECK_EQ(4u, Ranges[0].ChunkBlockIndexCount);
+}
+
+// GetBlockRanges: a gap in the needed chunk indexes splits the output into
+// separate ranges rather than spanning the unneeded chunk.
+TEST_CASE("chunkblock.getblockranges.non_contiguous")
+{
+ using namespace chunkblock_analyser_testutils;
+
+ // Chunks 0 and 2 needed, chunk 1 skipped → two separate ranges
+ auto Block = MakeBlockDesc(50, {100, 200, 300});
+ uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
+
+ std::vector<uint32_t> Needed = {0, 2};
+ auto Ranges = chunkblock_impl::GetBlockRanges(Block, ChunkStartOffset, Needed);
+
+ REQUIRE_EQ(2u, Ranges.size());
+
+ CHECK_EQ(ChunkStartOffset, Ranges[0].RangeStart);
+ CHECK_EQ(100u, Ranges[0].RangeLength);
+ CHECK_EQ(0u, Ranges[0].ChunkBlockIndexStart);
+ CHECK_EQ(1u, Ranges[0].ChunkBlockIndexCount);
+
+ CHECK_EQ(ChunkStartOffset + 300u, Ranges[1].RangeStart); // 100+200 before chunk 2
+ CHECK_EQ(300u, Ranges[1].RangeLength);
+ CHECK_EQ(2u, Ranges[1].ChunkBlockIndexStart);
+ CHECK_EQ(1u, Ranges[1].ChunkBlockIndexCount);
+}
+
+// GetBlockRanges: consecutive needed chunk indexes merge into one range.
+TEST_CASE("chunkblock.getblockranges.contiguous_run")
+{
+ using namespace chunkblock_analyser_testutils;
+
+ // Chunks 1, 2, 3 needed (consecutive) → one merged range
+ auto Block = MakeBlockDesc(50, {50, 100, 150, 200, 250});
+ uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
+
+ std::vector<uint32_t> Needed = {1, 2, 3};
+ auto Ranges = chunkblock_impl::GetBlockRanges(Block, ChunkStartOffset, Needed);
+
+ REQUIRE_EQ(1u, Ranges.size());
+ CHECK_EQ(ChunkStartOffset + 50u, Ranges[0].RangeStart); // 50 before chunk 1
+ CHECK_EQ(450u, Ranges[0].RangeLength); // 100+150+200
+ CHECK_EQ(1u, Ranges[0].ChunkBlockIndexStart);
+ CHECK_EQ(3u, Ranges[0].ChunkBlockIndexCount);
+}
+
+TEST_SUITE_END();
+
void
chunkblock_forcelink()
{
diff --git a/src/zenremotestore/chunking/chunkedcontent.cpp b/src/zenremotestore/chunking/chunkedcontent.cpp
index 26d179f14..c09ab9d3a 100644
--- a/src/zenremotestore/chunking/chunkedcontent.cpp
+++ b/src/zenremotestore/chunking/chunkedcontent.cpp
@@ -166,7 +166,6 @@ namespace {
if (Chunked.Info.ChunkSequence.empty())
{
AddChunkSequence(Stats, OutChunkedContent.ChunkedContent, ChunkHashToChunkIndex, Chunked.Info.RawHash, RawSize);
- Stats.UniqueSequencesFound++;
}
else
{
@@ -186,7 +185,6 @@ namespace {
Chunked.Info.ChunkHashes,
ChunkSizes);
}
- Stats.UniqueSequencesFound++;
}
});
Stats.FilesChunked++;
@@ -253,7 +251,7 @@ FolderContent::operator==(const FolderContent& Rhs) const
if ((Platform == Rhs.Platform) && (RawSizes == Rhs.RawSizes) && (Attributes == Rhs.Attributes) &&
(ModificationTicks == Rhs.ModificationTicks) && (Paths.size() == Rhs.Paths.size()))
{
- size_t PathCount = 0;
+ size_t PathCount = Paths.size();
for (size_t PathIndex = 0; PathIndex < PathCount; PathIndex++)
{
if (Paths[PathIndex].generic_string() != Rhs.Paths[PathIndex].generic_string())
@@ -1706,6 +1704,8 @@ namespace chunkedcontent_testutils {
} // namespace chunkedcontent_testutils
+TEST_SUITE_BEGIN("remotestore.chunkedcontent");
+
TEST_CASE("chunkedcontent.DeletePathsFromContent")
{
FastRandom BaseRandom;
@@ -1924,6 +1924,8 @@ TEST_CASE("chunkedcontent.ApplyChunkedContentOverlay")
}
}
+TEST_SUITE_END();
+
#endif // ZEN_WITH_TESTS
} // namespace zen
diff --git a/src/zenremotestore/chunking/chunkedfile.cpp b/src/zenremotestore/chunking/chunkedfile.cpp
index 652110605..633ddfd0d 100644
--- a/src/zenremotestore/chunking/chunkedfile.cpp
+++ b/src/zenremotestore/chunking/chunkedfile.cpp
@@ -211,6 +211,8 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
# if 0
+TEST_SUITE_BEGIN("remotestore.chunkedfile");
+
TEST_CASE("chunkedfile.findparams")
{
# if 1
@@ -513,6 +515,8 @@ TEST_CASE("chunkedfile.findparams")
// WorkLatch.CountDown();
// WorkLatch.Wait();
}
+
+TEST_SUITE_END();
# endif // 0
void
diff --git a/src/zenremotestore/chunking/chunkingcache.cpp b/src/zenremotestore/chunking/chunkingcache.cpp
index 7f0a26330..e9b783a00 100644
--- a/src/zenremotestore/chunking/chunkingcache.cpp
+++ b/src/zenremotestore/chunking/chunkingcache.cpp
@@ -75,13 +75,13 @@ public:
{
Lock.ReleaseNow();
RwLock::ExclusiveLockScope EditLock(m_Lock);
- if (auto RemoveIt = m_PathHashToEntry.find(PathHash); It != m_PathHashToEntry.end())
+ if (auto RemoveIt = m_PathHashToEntry.find(PathHash); RemoveIt != m_PathHashToEntry.end())
{
- CachedEntry& DeleteEntry = m_Entries[It->second];
+ CachedEntry& DeleteEntry = m_Entries[RemoveIt->second];
DeleteEntry.Chunked = {};
DeleteEntry.ModificationTick = 0;
- m_FreeEntryIndexes.push_back(It->second);
- m_PathHashToEntry.erase(It);
+ m_FreeEntryIndexes.push_back(RemoveIt->second);
+ m_PathHashToEntry.erase(RemoveIt);
}
}
}
@@ -461,6 +461,8 @@ namespace chunkingcache_testutils {
}
} // namespace chunkingcache_testutils
+TEST_SUITE_BEGIN("remotestore.chunkingcache");
+
TEST_CASE("chunkingcache.nullchunkingcache")
{
using namespace chunkingcache_testutils;
@@ -617,6 +619,8 @@ TEST_CASE("chunkingcache.diskchunkingcache")
}
}
+TEST_SUITE_END();
+
void
chunkingcache_forcelink()
{
diff --git a/src/zenremotestore/filesystemutils.cpp b/src/zenremotestore/filesystemutils.cpp
index fa1ce6f78..fdb2143d8 100644
--- a/src/zenremotestore/filesystemutils.cpp
+++ b/src/zenremotestore/filesystemutils.cpp
@@ -637,6 +637,8 @@ namespace {
void GenerateFile(const std::filesystem::path& Path) { BasicFile _(Path, BasicFile::Mode::kTruncate); }
} // namespace
+TEST_SUITE_BEGIN("remotestore.filesystemutils");
+
TEST_CASE("filesystemutils.CleanDirectory")
{
ScopedTemporaryDirectory TmpDir;
@@ -692,6 +694,8 @@ TEST_CASE("filesystemutils.CleanDirectory")
CHECK(!IsFile(TmpDir.Path() / "CantDeleteMe2" / "deleteme"));
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorage.h b/src/zenremotestore/include/zenremotestore/builds/buildstorage.h
index 85dabc59f..da8437a58 100644
--- a/src/zenremotestore/include/zenremotestore/builds/buildstorage.h
+++ b/src/zenremotestore/include/zenremotestore/builds/buildstorage.h
@@ -53,15 +53,24 @@ public:
std::function<IoBuffer(uint64_t Offset, uint64_t Size)>&& Transmitter,
std::function<void(uint64_t, bool)>&& OnSentBytes) = 0;
- virtual IoBuffer GetBuildBlob(const Oid& BuildId,
- const IoHash& RawHash,
- uint64_t RangeOffset = 0,
- uint64_t RangeBytes = (uint64_t)-1) = 0;
+ virtual IoBuffer GetBuildBlob(const Oid& BuildId,
+ const IoHash& RawHash,
+ uint64_t RangeOffset = 0,
+ uint64_t RangeBytes = (uint64_t)-1) = 0;
+
+ struct BuildBlobRanges
+ {
+ IoBuffer PayloadBuffer;
+ std::vector<std::pair<uint64_t, uint64_t>> Ranges;
+ };
+ virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId,
+ const IoHash& RawHash,
+ std::span<const std::pair<uint64_t, uint64_t>> Ranges) = 0;
virtual std::vector<std::function<void()>> GetLargeBuildBlob(const Oid& BuildId,
const IoHash& RawHash,
uint64_t ChunkSize,
std::function<void(uint64_t Offset, const IoBuffer& Chunk)>&& OnReceive,
- std::function<void()>&& OnComplete) = 0;
+ std::function<void()>&& OnComplete) = 0;
[[nodiscard]] virtual bool PutBlockMetadata(const Oid& BuildId, const IoHash& BlockRawHash, const CbObject& MetaData) = 0;
virtual CbObject FindBlocks(const Oid& BuildId, uint64_t MaxBlockCount) = 0;
diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h b/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h
index bb5b1c5f4..24702df0f 100644
--- a/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h
+++ b/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h
@@ -37,6 +37,14 @@ public:
const IoHash& RawHash,
uint64_t RangeOffset = 0,
uint64_t RangeBytes = (uint64_t)-1) = 0;
+ struct BuildBlobRanges
+ {
+ IoBuffer PayloadBuffer;
+ std::vector<std::pair<uint64_t, uint64_t>> Ranges;
+ };
+ virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId,
+ const IoHash& RawHash,
+ std::span<const std::pair<uint64_t, uint64_t>> Ranges) = 0;
virtual void PutBlobMetadatas(const Oid& BuildId, std::span<const IoHash> BlobHashes, std::span<const CbObject> MetaDatas) = 0;
virtual std::vector<CbObject> GetBlobMetadatas(const Oid& BuildId, std::span<const IoHash> BlobHashes) = 0;
@@ -61,10 +69,19 @@ std::unique_ptr<BuildStorageCache> CreateZenBuildStorageCache(HttpClient& H
const std::filesystem::path& TempFolderPath,
WorkerThreadPool& BackgroundWorkerPool);
+#if ZEN_WITH_TESTS
+std::unique_ptr<BuildStorageCache> CreateInMemoryBuildStorageCache(uint64_t MaxRangeSupported,
+ BuildStorageCache::Statistics& Stats,
+ double LatencySec = 0.0,
+ double DelayPerKBSec = 0.0);
+#endif // ZEN_WITH_TESTS
+
struct ZenCacheEndpointTestResult
{
bool Success = false;
std::string FailureReason;
+ double LatencySeconds = -1.0;
+ uint64_t MaxRangeCountPerRequest = 1;
};
ZenCacheEndpointTestResult TestZenCacheEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpVerbose);
diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h
index 6304159ae..0d2eded58 100644
--- a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h
+++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h
@@ -7,7 +7,9 @@
#include <zencore/uid.h>
#include <zencore/zencore.h>
#include <zenremotestore/builds/buildstoragecache.h>
+#include <zenremotestore/chunking/chunkblock.h>
#include <zenremotestore/chunking/chunkedcontent.h>
+#include <zenremotestore/partialblockrequestmode.h>
#include <zenutil/bufferedwritefilecache.h>
#include <atomic>
@@ -108,17 +110,6 @@ struct RebuildFolderStateStatistics
uint64_t FinalizeTreeElapsedWallTimeUs = 0;
};
-enum EPartialBlockRequestMode
-{
- Off,
- ZenCacheOnly,
- Mixed,
- All,
- Invalid
-};
-
-EPartialBlockRequestMode PartialBlockRequestModeFromString(const std::string_view ModeString);
-
std::filesystem::path ZenStateFilePath(const std::filesystem::path& ZenFolderPath);
std::filesystem::path ZenTempFolderPath(const std::filesystem::path& ZenFolderPath);
@@ -170,7 +161,7 @@ public:
DownloadStatistics m_DownloadStats;
WriteChunkStatistics m_WriteChunkStats;
RebuildFolderStateStatistics m_RebuildFolderStateStats;
- std::atomic<uint64_t> m_WrittenChunkByteCount;
+ std::atomic<uint64_t> m_WrittenChunkByteCount = 0;
private:
struct BlockWriteOps
@@ -195,7 +186,7 @@ private:
uint32_t ScavengedContentIndex = (uint32_t)-1;
uint32_t ScavengedPathIndex = (uint32_t)-1;
uint32_t RemoteSequenceIndex = (uint32_t)-1;
- uint64_t RawSize = (uint32_t)-1;
+ uint64_t RawSize = (uint64_t)-1;
};
struct CopyChunkData
@@ -218,33 +209,6 @@ private:
uint64_t ElapsedTimeMs = 0;
};
- struct BlockRangeDescriptor
- {
- uint32_t BlockIndex = (uint32_t)-1;
- uint64_t RangeStart = 0;
- uint64_t RangeLength = 0;
- uint32_t ChunkBlockIndexStart = 0;
- uint32_t ChunkBlockIndexCount = 0;
- };
-
- struct BlockRangeLimit
- {
- uint16_t SizePercent;
- uint16_t MaxRangeCount;
- };
-
- static constexpr uint16_t FullBlockRangePercentLimit = 95;
-
- static constexpr BuildsOperationUpdateFolder::BlockRangeLimit ForceMergeLimits[] = {
- {.SizePercent = FullBlockRangePercentLimit, .MaxRangeCount = 1},
- {.SizePercent = 90, .MaxRangeCount = 2},
- {.SizePercent = 85, .MaxRangeCount = 8},
- {.SizePercent = 80, .MaxRangeCount = 16},
- {.SizePercent = 70, .MaxRangeCount = 32},
- {.SizePercent = 60, .MaxRangeCount = 48},
- {.SizePercent = 2, .MaxRangeCount = 56},
- {.SizePercent = 0, .MaxRangeCount = 64}};
-
void ScanCacheFolder(tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutCachedChunkHashesFound,
tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutCachedSequenceHashesFound);
void ScanTempBlocksFolder(tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutCachedBlocksFound);
@@ -299,25 +263,14 @@ private:
ParallelWork& Work,
std::function<void(IoBuffer&& Payload)>&& OnDownloaded);
- BlockRangeDescriptor MergeBlockRanges(std::span<const BlockRangeDescriptor> Ranges);
- std::optional<std::vector<BlockRangeDescriptor>> MakeOptionalBlockRangeVector(uint64_t TotalBlockSize,
- const BlockRangeDescriptor& Range);
- const BlockRangeLimit* GetBlockRangeLimitForRange(std::span<const BlockRangeLimit> Limits,
- uint64_t TotalBlockSize,
- std::span<const BlockRangeDescriptor> Ranges);
- std::vector<BlockRangeDescriptor> CollapseBlockRanges(const uint64_t AlwaysAcceptableGap,
- std::span<const BlockRangeDescriptor> BlockRanges);
- uint64_t CalculateNextGap(std::span<const BlockRangeDescriptor> BlockRanges);
- std::optional<std::vector<BlockRangeDescriptor>> CalculateBlockRanges(uint32_t BlockIndex,
- const ChunkBlockDescription& BlockDescription,
- std::span<const uint32_t> BlockChunkIndexNeeded,
- bool LimitToSingleRange,
- const uint64_t ChunkStartOffsetInBlock,
- const uint64_t TotalBlockSize,
- uint64_t& OutTotalWantedChunksSize);
- void DownloadPartialBlock(const BlockRangeDescriptor BlockRange,
- const BlobsExistsResult& ExistsResult,
- std::function<void(IoBuffer&& InMemoryBuffer, const std::filesystem::path& OnDiskPath)>&& OnDownloaded);
+ void DownloadPartialBlock(std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRanges,
+ size_t BlockRangeIndex,
+ size_t BlockRangeCount,
+ const BlobsExistsResult& ExistsResult,
+ std::function<void(IoBuffer&& InMemoryBuffer,
+ const std::filesystem::path& OnDiskPath,
+ size_t BlockRangeStartIndex,
+ std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths)>&& OnDownloaded);
std::vector<uint32_t> WriteLocalChunkToCache(CloneQueryInterface* CloneQuery,
const CopyChunkData& CopyData,
@@ -339,7 +292,8 @@ private:
const uint64_t FileOffset,
const uint32_t PathIndex);
- bool GetBlockWriteOps(std::span<const IoHash> ChunkRawHashes,
+ bool GetBlockWriteOps(const IoHash& BlockRawHash,
+ std::span<const IoHash> ChunkRawHashes,
std::span<const uint32_t> ChunkCompressedLengths,
std::span<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters,
std::span<std::atomic<bool>> RemoteChunkIndexNeedsCopyFromSourceFlags,
@@ -408,7 +362,7 @@ private:
const std::filesystem::path m_TempDownloadFolderPath;
const std::filesystem::path m_TempBlockFolderPath;
- std::atomic<uint64_t> m_ValidatedChunkByteCount;
+ std::atomic<uint64_t> m_ValidatedChunkByteCount = 0;
};
struct FindBlocksStatistics
diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h
index ab3037c89..7306188ca 100644
--- a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h
+++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h
@@ -14,13 +14,20 @@ class BuildStorageCache;
struct BuildStorageResolveResult
{
- std::string HostUrl;
- std::string HostName;
- bool HostAssumeHttp2 = false;
-
- std::string CacheUrl;
- std::string CacheName;
- bool CacheAssumeHttp2 = false;
+ struct Capabilities
+ {
+ uint64_t MaxRangeCountPerRequest = 1;
+ };
+ struct Host
+ {
+ std::string Address;
+ std::string Name;
+ bool AssumeHttp2 = false;
+ double LatencySec = -1.0;
+ Capabilities Caps;
+ };
+ Host Cloud;
+ Host Cache;
};
enum class ZenCacheResolveMode
@@ -43,7 +50,6 @@ std::vector<ChunkBlockDescription> GetBlockDescriptions(OperationLogOutput& Out
BuildStorageBase& Storage,
BuildStorageCache* OptionalCacheStorage,
const Oid& BuildId,
- const Oid& BuildPartId,
std::span<const IoHash> BlockRawHashes,
bool AttemptFallback,
bool IsQuiet,
@@ -51,12 +57,13 @@ std::vector<ChunkBlockDescription> GetBlockDescriptions(OperationLogOutput& Out
struct StorageInstance
{
- std::unique_ptr<HttpClient> BuildStorageHttp;
- std::unique_ptr<BuildStorageBase> BuildStorage;
- std::string StorageName;
+ BuildStorageResolveResult::Host BuildStorageHost;
+ std::unique_ptr<HttpClient> BuildStorageHttp;
+ std::unique_ptr<BuildStorageBase> BuildStorage;
+
+ BuildStorageResolveResult::Host CacheHost;
std::unique_ptr<HttpClient> CacheHttp;
- std::unique_ptr<BuildStorageCache> BuildCacheStorage;
- std::string CacheName;
+ std::unique_ptr<BuildStorageCache> CacheStorage;
};
} // namespace zen
diff --git a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h
index d339b0f94..931bb2097 100644
--- a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h
+++ b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h
@@ -7,8 +7,9 @@
#include <zencore/compactbinary.h>
#include <zencore/compress.h>
-#include <optional>
-#include <vector>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
@@ -20,13 +21,14 @@ struct ThinChunkBlockDescription
struct ChunkBlockDescription : public ThinChunkBlockDescription
{
- uint64_t HeaderSize;
+ uint64_t HeaderSize = 0;
std::vector<uint32_t> ChunkRawLengths;
std::vector<uint32_t> ChunkCompressedLengths;
};
std::vector<ChunkBlockDescription> ParseChunkBlockDescriptionList(const CbObjectView& BlocksObject);
ChunkBlockDescription ParseChunkBlockDescription(const CbObjectView& BlockObject);
+std::vector<ChunkBlockDescription> ParseBlockMetadatas(std::span<const CbObject> BlockMetadatas);
CbObject BuildChunkBlockDescription(const ChunkBlockDescription& Block, CbObjectView MetaData);
ChunkBlockDescription GetChunkBlockDescription(const SharedBuffer& BlockPayload, const IoHash& RawHash);
typedef std::function<std::pair<uint64_t, CompressedBuffer>(const IoHash& RawHash)> FetchChunkFunc;
@@ -73,6 +75,70 @@ std::vector<size_t> FindReuseBlocks(OperationLogOutput& Output,
std::span<const uint32_t> ChunkIndexes,
std::vector<uint32_t>& OutUnusedChunkIndexes);
+class ChunkBlockAnalyser
+{
+public:
+ struct Options
+ {
+ bool IsQuiet = false;
+ bool IsVerbose = false;
+ double HostLatencySec = -1.0;
+ double HostHighSpeedLatencySec = -1.0;
+ uint64_t HostSpeedBytesPerSec = (1u * 1024u * 1024u * 1024u) / 8u; // 1GBit
+ uint64_t HostHighSpeedBytesPerSec = (2u * 1024u * 1024u * 1024u) / 8u; // 2GBit
+ uint64_t HostMaxRangeCountPerRequest = (uint64_t)-1;
+ uint64_t HostHighSpeedMaxRangeCountPerRequest = (uint64_t)-1; // No limit
+ uint64_t MaxRangesPerBlock = 1024u;
+ };
+
+ ChunkBlockAnalyser(OperationLogOutput& LogOutput, std::span<const ChunkBlockDescription> BlockDescriptions, const Options& Options);
+
+ struct BlockRangeDescriptor
+ {
+ uint32_t BlockIndex = (uint32_t)-1;
+ uint64_t RangeStart = 0;
+ uint64_t RangeLength = 0;
+ uint32_t ChunkBlockIndexStart = 0;
+ uint32_t ChunkBlockIndexCount = 0;
+ };
+
+ struct NeededBlock
+ {
+ uint32_t BlockIndex;
+ std::vector<uint32_t> ChunkIndexes;
+ };
+
+ std::vector<NeededBlock> GetNeeded(const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& ChunkHashToChunkIndex,
+ std::function<bool(uint32_t ChunkIndex)>&& NeedsBlockChunk);
+
+ enum class EPartialBlockDownloadMode
+ {
+ Off,
+ SingleRange,
+ MultiRange,
+ MultiRangeHighSpeed,
+ Exact
+ };
+
+ struct BlockResult
+ {
+ std::vector<BlockRangeDescriptor> BlockRanges;
+ std::vector<uint32_t> FullBlockIndexes;
+ };
+
+ BlockResult CalculatePartialBlockDownloads(std::span<const NeededBlock> NeededBlocks,
+ std::span<const EPartialBlockDownloadMode> BlockPartialDownloadModes);
+
+private:
+ OperationLogOutput& m_LogOutput;
+ const std::span<const ChunkBlockDescription> m_BlockDescriptions;
+ const Options m_Options;
+};
+
+#if ZEN_WITH_TESTS
+
void chunkblock_forcelink();
+#endif // ZEN_WITH_TESTS
+
} // namespace zen
diff --git a/src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h b/src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h
index d402bd3f0..f44381e42 100644
--- a/src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h
+++ b/src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h
@@ -231,7 +231,7 @@ GetSequenceIndexForRawHash(const ChunkedContentLookup& Lookup, const IoHash& Raw
inline uint32_t
GetChunkIndexForRawHash(const ChunkedContentLookup& Lookup, const IoHash& RawHash)
{
- return Lookup.RawHashToSequenceIndex.at(RawHash);
+ return Lookup.ChunkHashToChunkIndex.at(RawHash);
}
inline uint32_t
diff --git a/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h b/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h
index 432496bc1..caf7ecd28 100644
--- a/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h
+++ b/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h
@@ -2,6 +2,7 @@
#pragma once
+#include <cstdint>
#include <string>
#include <string_view>
#include <vector>
@@ -28,6 +29,8 @@ struct JupiterEndpointTestResult
{
bool Success = false;
std::string FailureReason;
+ double LatencySeconds = -1.0;
+ uint64_t MaxRangeCountPerRequest = 1;
};
JupiterEndpointTestResult TestJupiterEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpVerbose);
diff --git a/src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h b/src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h
index eaf6962fd..8721bc37f 100644
--- a/src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h
+++ b/src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h
@@ -56,6 +56,11 @@ struct FinalizeBuildPartResult : JupiterResult
std::vector<IoHash> Needs;
};
+struct BuildBlobRangesResult : JupiterResult
+{
+ std::vector<std::pair<uint64_t, uint64_t>> Ranges;
+};
+
/**
* Context for performing Jupiter operations
*
@@ -135,6 +140,13 @@ public:
uint64_t Offset = 0,
uint64_t Size = (uint64_t)-1);
+ BuildBlobRangesResult GetBuildBlob(std::string_view Namespace,
+ std::string_view BucketId,
+ const Oid& BuildId,
+ const IoHash& Hash,
+ std::filesystem::path TempFolderPath,
+ std::span<const std::pair<uint64_t, uint64_t>> Ranges);
+
JupiterResult PutMultipartBuildBlob(std::string_view Namespace,
std::string_view BucketId,
const Oid& BuildId,
diff --git a/src/zenremotestore/include/zenremotestore/operationlogoutput.h b/src/zenremotestore/include/zenremotestore/operationlogoutput.h
index 9693e69cf..32b95f50f 100644
--- a/src/zenremotestore/include/zenremotestore/operationlogoutput.h
+++ b/src/zenremotestore/include/zenremotestore/operationlogoutput.h
@@ -3,6 +3,7 @@
#pragma once
#include <zencore/fmtutils.h>
+#include <zencore/logbase.h>
namespace zen {
@@ -10,7 +11,7 @@ class OperationLogOutput
{
public:
virtual ~OperationLogOutput() {}
- virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) = 0;
+ virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) = 0;
virtual void SetLogOperationName(std::string_view Name) = 0;
virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) = 0;
@@ -57,23 +58,19 @@ public:
virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) = 0;
};
-struct LoggerRef;
+OperationLogOutput* CreateStandardLogOutput(LoggerRef Log);
-OperationLogOutput* CreateStandardLogOutput(LoggerRef& Log);
-
-#define ZEN_OPERATION_LOG(OutputTarget, InLevel, fmtstr, ...) \
- do \
- { \
- using namespace std::literals; \
- ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
- OutputTarget.EmitLogMessage(InLevel, fmtstr, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
+#define ZEN_OPERATION_LOG(OutputTarget, InLevel, fmtstr, ...) \
+ do \
+ { \
+ using namespace std::literals; \
+ static constinit zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \
+ ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
+ (OutputTarget).EmitLogMessage(LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
} while (false)
-#define ZEN_OPERATION_LOG_INFO(OutputTarget, fmtstr, ...) \
- ZEN_OPERATION_LOG((OutputTarget), zen::logging::level::Info, fmtstr, ##__VA_ARGS__)
-#define ZEN_OPERATION_LOG_DEBUG(OutputTarget, fmtstr, ...) \
- ZEN_OPERATION_LOG((OutputTarget), zen::logging::level::Debug, fmtstr, ##__VA_ARGS__)
-#define ZEN_OPERATION_LOG_WARN(OutputTarget, fmtstr, ...) \
- ZEN_OPERATION_LOG((OutputTarget), zen::logging::level::Warn, fmtstr, ##__VA_ARGS__)
+#define ZEN_OPERATION_LOG_INFO(OutputTarget, fmtstr, ...) ZEN_OPERATION_LOG(OutputTarget, zen::logging::Info, fmtstr, ##__VA_ARGS__)
+#define ZEN_OPERATION_LOG_DEBUG(OutputTarget, fmtstr, ...) ZEN_OPERATION_LOG(OutputTarget, zen::logging::Debug, fmtstr, ##__VA_ARGS__)
+#define ZEN_OPERATION_LOG_WARN(OutputTarget, fmtstr, ...) ZEN_OPERATION_LOG(OutputTarget, zen::logging::Warn, fmtstr, ##__VA_ARGS__)
} // namespace zen
diff --git a/src/zenremotestore/include/zenremotestore/partialblockrequestmode.h b/src/zenremotestore/include/zenremotestore/partialblockrequestmode.h
new file mode 100644
index 000000000..54adea2b2
--- /dev/null
+++ b/src/zenremotestore/include/zenremotestore/partialblockrequestmode.h
@@ -0,0 +1,20 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <string_view>
+
+namespace zen {
+
+enum EPartialBlockRequestMode
+{
+ Off,
+ ZenCacheOnly,
+ Mixed,
+ All,
+ Invalid
+};
+
+EPartialBlockRequestMode PartialBlockRequestModeFromString(const std::string_view ModeString);
+
+} // namespace zen
diff --git a/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h b/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h
index e8b7c15c0..c058e1c1f 100644
--- a/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h
+++ b/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h
@@ -2,6 +2,7 @@
#pragma once
+#include <zenhttp/httpclient.h>
#include <zenremotestore/projectstore/remoteprojectstore.h>
namespace zen {
@@ -10,9 +11,6 @@ class AuthMgr;
struct BuildsRemoteStoreOptions : RemoteStoreOptions
{
- std::string Host;
- std::string OverrideHost;
- std::string ZenHost;
std::string Namespace;
std::string Bucket;
Oid BuildId;
@@ -22,18 +20,16 @@ struct BuildsRemoteStoreOptions : RemoteStoreOptions
std::filesystem::path OidcExePath;
bool ForceDisableBlocks = false;
bool ForceDisableTempBlocks = false;
- bool AssumeHttp2 = false;
- bool PopulateCache = true;
IoBuffer MetaData;
size_t MaximumInMemoryDownloadSize = 1024u * 1024u;
};
-std::shared_ptr<RemoteProjectStore> CreateJupiterBuildsRemoteStore(LoggerRef InLog,
- const BuildsRemoteStoreOptions& Options,
- const std::filesystem::path& TempFilePath,
- bool Quiet,
- bool Unattended,
- bool Hidden,
- WorkerThreadPool& CacheBackgroundWorkerPool);
+struct BuildStorageResolveResult;
+
+std::shared_ptr<RemoteProjectStore> CreateJupiterBuildsRemoteStore(LoggerRef InLog,
+ const BuildStorageResolveResult& ResolveResult,
+ std::function<HttpClientAccessToken()>&& TokenProvider,
+ const BuildsRemoteStoreOptions& Options,
+ const std::filesystem::path& TempFilePath);
} // namespace zen
diff --git a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h
index 008f94351..084d975a2 100644
--- a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h
+++ b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h
@@ -5,7 +5,9 @@
#include <zencore/jobqueue.h>
#include <zenstore/projectstore.h>
+#include <zenremotestore/builds/buildstoragecache.h>
#include <zenremotestore/chunking/chunkblock.h>
+#include <zenremotestore/partialblockrequestmode.h>
#include <unordered_set>
@@ -73,24 +75,35 @@ public:
std::vector<ChunkBlockDescription> Blocks;
};
+ struct GetBlockDescriptionsResult : public Result
+ {
+ std::vector<ChunkBlockDescription> Blocks;
+ };
+
+ struct LoadAttachmentRangesResult : public Result
+ {
+ IoBuffer Bytes;
+ std::vector<std::pair<uint64_t, uint64_t>> Ranges;
+ };
+
struct RemoteStoreInfo
{
- bool CreateBlocks;
- bool UseTempBlockFiles;
- bool AllowChunking;
+ bool CreateBlocks = false;
+ bool UseTempBlockFiles = false;
+ bool AllowChunking = false;
std::string ContainerName;
std::string Description;
};
struct Stats
{
- std::uint64_t m_SentBytes;
- std::uint64_t m_ReceivedBytes;
- std::uint64_t m_RequestTimeNS;
- std::uint64_t m_RequestCount;
- std::uint64_t m_PeakSentBytes;
- std::uint64_t m_PeakReceivedBytes;
- std::uint64_t m_PeakBytesPerSec;
+ std::uint64_t m_SentBytes = 0;
+ std::uint64_t m_ReceivedBytes = 0;
+ std::uint64_t m_RequestTimeNS = 0;
+ std::uint64_t m_RequestCount = 0;
+ std::uint64_t m_PeakSentBytes = 0;
+ std::uint64_t m_PeakReceivedBytes = 0;
+ std::uint64_t m_PeakBytesPerSec = 0;
};
struct ExtendedStats
@@ -111,12 +124,17 @@ public:
virtual FinalizeResult FinalizeContainer(const IoHash& RawHash) = 0;
virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Payloads) = 0;
- virtual LoadContainerResult LoadContainer() = 0;
- virtual GetKnownBlocksResult GetKnownBlocks() = 0;
- virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) = 0;
- virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) = 0;
+ virtual LoadContainerResult LoadContainer() = 0;
+ virtual GetKnownBlocksResult GetKnownBlocks() = 0;
+ virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span<const IoHash> BlockHashes,
+ BuildStorageCache* OptionalCache,
+ const Oid& CacheBuildId) = 0;
+
+ virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) = 0;
- virtual void Flush() = 0;
+ virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash,
+ std::span<const std::pair<uint64_t, uint64_t>> Ranges) = 0;
+ virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) = 0;
};
struct RemoteStoreOptions
@@ -153,14 +171,15 @@ RemoteProjectStore::LoadContainerResult BuildContainer(
class JobContext;
-RemoteProjectStore::Result SaveOplogContainer(ProjectStore::Oplog& Oplog,
- const CbObject& ContainerObject,
- const std::function<void(std::span<IoHash> RawHashes)>& OnReferencedAttachments,
- const std::function<bool(const IoHash& RawHash)>& HasAttachment,
- const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock,
- const std::function<void(const IoHash& RawHash)>& OnNeedAttachment,
- const std::function<void(const ChunkedInfo& Chunked)>& OnChunkedAttachment,
- JobContext* OptionalContext);
+RemoteProjectStore::Result SaveOplogContainer(
+ ProjectStore::Oplog& Oplog,
+ const CbObject& ContainerObject,
+ const std::function<void(std::span<IoHash> RawHashes)>& OnReferencedAttachments,
+ const std::function<bool(const IoHash& RawHash)>& HasAttachment,
+ const std::function<void(ThinChunkBlockDescription&& ThinBlockDescription, std::vector<uint32_t>&& NeededChunkIndexes)>& OnNeedBlock,
+ const std::function<void(const IoHash& RawHash)>& OnNeedAttachment,
+ const std::function<void(const ChunkedInfo& Chunked)>& OnChunkedAttachment,
+ JobContext* OptionalContext);
RemoteProjectStore::Result SaveOplog(CidStore& ChunkStore,
RemoteProjectStore& RemoteStore,
@@ -177,15 +196,29 @@ RemoteProjectStore::Result SaveOplog(CidStore& ChunkStore,
bool IgnoreMissingAttachments,
JobContext* OptionalContext);
-RemoteProjectStore::Result LoadOplog(CidStore& ChunkStore,
- RemoteProjectStore& RemoteStore,
- ProjectStore::Oplog& Oplog,
- WorkerThreadPool& NetworkWorkerPool,
- WorkerThreadPool& WorkerPool,
- bool ForceDownload,
- bool IgnoreMissingAttachments,
- bool CleanOplog,
- JobContext* OptionalContext);
+struct LoadOplogContext
+{
+ CidStore& ChunkStore;
+ RemoteProjectStore& RemoteStore;
+ BuildStorageCache* OptionalCache = nullptr;
+ Oid CacheBuildId = Oid::Zero;
+ BuildStorageCache::Statistics* OptionalCacheStats = nullptr;
+ ProjectStore::Oplog& Oplog;
+ WorkerThreadPool& NetworkWorkerPool;
+ WorkerThreadPool& WorkerPool;
+ bool ForceDownload = false;
+ bool IgnoreMissingAttachments = false;
+ bool CleanOplog = false;
+ EPartialBlockRequestMode PartialBlockRequestMode = EPartialBlockRequestMode::All;
+ bool PopulateCache = false;
+ double StoreLatencySec = -1.0;
+ uint64_t StoreMaxRangeCountPerRequest = 1;
+ double CacheLatencySec = -1.0;
+ uint64_t CacheMaxRangeCountPerRequest = 1;
+ JobContext* OptionalJobContext = nullptr;
+};
+
+RemoteProjectStore::Result LoadOplog(LoadOplogContext&& Context);
std::vector<IoHash> GetBlockHashesFromOplog(CbObjectView ContainerObject);
std::vector<ThinChunkBlockDescription> GetBlocksFromOplog(CbObjectView ContainerObject, std::span<const IoHash> IncludeBlockHashes);
diff --git a/src/zenremotestore/jupiter/jupiterhost.cpp b/src/zenremotestore/jupiter/jupiterhost.cpp
index 7706f00c2..314aafc78 100644
--- a/src/zenremotestore/jupiter/jupiterhost.cpp
+++ b/src/zenremotestore/jupiter/jupiterhost.cpp
@@ -59,7 +59,22 @@ TestJupiterEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool
HttpClient::Response TestResponse = TestHttpClient.Get("/health/live");
if (TestResponse.IsSuccess())
{
- return {.Success = true};
+ // TODO: dan.engelbrecht 20260305 - replace this naive nginx detection with proper capabilities endpoint once it exists in Jupiter
+ uint64_t MaxRangeCountPerRequest = 1;
+ if (auto It = TestResponse.Header.Entries.find("Server"); It != TestResponse.Header.Entries.end())
+ {
+ if (StrCaseCompare(It->second.c_str(), "nginx", 5) == 0)
+ {
+ MaxRangeCountPerRequest = 128u; // This leaves more than 2k header space for auth token etc
+ }
+ }
+ LatencyTestResult LatencyResult = MeasureLatency(TestHttpClient, "/health/ready");
+
+ if (!LatencyResult.Success)
+ {
+ return {.Success = false, .FailureReason = LatencyResult.FailureReason};
+ }
+ return {.Success = true, .LatencySeconds = LatencyResult.LatencySeconds, .MaxRangeCountPerRequest = MaxRangeCountPerRequest};
}
return {.Success = false, .FailureReason = TestResponse.ErrorMessage("")};
}
diff --git a/src/zenremotestore/jupiter/jupitersession.cpp b/src/zenremotestore/jupiter/jupitersession.cpp
index 1bc6564ce..52f9eb678 100644
--- a/src/zenremotestore/jupiter/jupitersession.cpp
+++ b/src/zenremotestore/jupiter/jupitersession.cpp
@@ -852,6 +852,71 @@ JupiterSession::GetBuildBlob(std::string_view Namespace,
return detail::ConvertResponse(Response, "JupiterSession::GetBuildBlob"sv);
}
+BuildBlobRangesResult
+JupiterSession::GetBuildBlob(std::string_view Namespace,
+ std::string_view BucketId,
+ const Oid& BuildId,
+ const IoHash& Hash,
+ std::filesystem::path TempFolderPath,
+ std::span<const std::pair<uint64_t, uint64_t>> Ranges)
+{
+ HttpClient::KeyValueMap Headers;
+ if (!Ranges.empty())
+ {
+ ExtendableStringBuilder<512> SB;
+ for (const std::pair<uint64_t, uint64_t>& R : Ranges)
+ {
+ if (SB.Size() > 0)
+ {
+ SB << ", ";
+ }
+ SB << R.first << "-" << R.first + R.second - 1;
+ }
+ Headers.Entries.insert({"Range", fmt::format("bytes={}", SB.ToView())});
+ }
+ std::string Url = fmt::format("/api/v2/builds/{}/{}/{}/blobs/{}?supportsRedirect={}",
+ Namespace,
+ BucketId,
+ BuildId,
+ Hash.ToHexString(),
+ m_AllowRedirect ? "true"sv : "false"sv);
+
+ HttpClient::Response Response = m_HttpClient.Download(Url, TempFolderPath, Headers);
+ if (Response.StatusCode == HttpResponseCode::RangeNotSatisfiable && Ranges.size() > 1)
+ {
+ // Requests to Jupiter that are not served via nginx (content not stored locally in the file system) cannot serve multi-range
+ // requests (asp.net limitation). This rejection is not implemented as of 2026-03-02; it is in the backlog (@joakim.lindqvist)
+ // If we encounter this error we fall back to a single range which covers all the requested ranges
+ uint64_t RangeStart = Ranges.front().first;
+ uint64_t RangeEnd = Ranges.back().first + Ranges.back().second - 1;
+ Headers.Entries.insert_or_assign("Range", fmt::format("bytes={}-{}", RangeStart, RangeEnd));
+ Response = m_HttpClient.Download(Url, TempFolderPath, Headers);
+ }
+ if (Response.IsSuccess())
+ {
+ // If we get a redirect to S3 or a non-Jupiter endpoint the content type will not be correct, validate it and set it
+ if (m_AllowRedirect && (Response.ResponsePayload.GetContentType() == HttpContentType::kBinary))
+ {
+ IoHash ValidateRawHash;
+ uint64_t ValidateRawSize = 0;
+ if (!Headers.Entries.contains("Range"))
+ {
+ ZEN_ASSERT_SLOW(CompressedBuffer::ValidateCompressedHeader(Response.ResponsePayload,
+ ValidateRawHash,
+ ValidateRawSize,
+ /*OutOptionalTotalCompressedSize*/ nullptr));
+ ZEN_ASSERT_SLOW(ValidateRawHash == Hash);
+ ZEN_ASSERT_SLOW(ValidateRawSize > 0);
+ ZEN_UNUSED(ValidateRawHash, ValidateRawSize);
+ Response.ResponsePayload.SetContentType(ZenContentType::kCompressedBinary);
+ }
+ }
+ }
+ BuildBlobRangesResult Result = {detail::ConvertResponse(Response, "JupiterSession::GetBuildBlob"sv)};
+ Result.Ranges = Response.GetRanges(Ranges);
+ return Result;
+}
+
JupiterResult
JupiterSession::PutBlockMetadata(std::string_view Namespace,
std::string_view BucketId,
diff --git a/src/zenremotestore/operationlogoutput.cpp b/src/zenremotestore/operationlogoutput.cpp
index 0837ed716..5ed844c9d 100644
--- a/src/zenremotestore/operationlogoutput.cpp
+++ b/src/zenremotestore/operationlogoutput.cpp
@@ -3,6 +3,7 @@
#include <zenremotestore/operationlogoutput.h>
#include <zencore/logging.h>
+#include <zencore/logging/logger.h>
ZEN_THIRD_PARTY_INCLUDES_START
#include <gsl/gsl-lite.hpp>
@@ -30,13 +31,11 @@ class StandardLogOutput : public OperationLogOutput
{
public:
StandardLogOutput(LoggerRef& Log) : m_Log(Log) {}
- virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) override
+ virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override
{
- if (m_Log.ShouldLog(LogLevel))
+ if (m_Log.ShouldLog(Point.Level))
{
- fmt::basic_memory_buffer<char, 250> MessageBuffer;
- fmt::vformat_to(fmt::appender(MessageBuffer), Format, Args);
- ZEN_LOG(m_Log, LogLevel, "{}", std::string_view(MessageBuffer.data(), MessageBuffer.size()));
+ m_Log->Log(Point, Args);
}
}
@@ -47,7 +46,7 @@ public:
}
virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override
{
- const size_t PercentDone = StepCount > 0u ? gsl::narrow<uint8_t>((100 * StepIndex) / StepCount) : 0u;
+ [[maybe_unused]] const size_t PercentDone = StepCount > 0u ? gsl::narrow<uint8_t>((100 * StepIndex) / StepCount) : 0u;
ZEN_OPERATION_LOG_INFO(*this, "{}: {}%", m_LogOperationName, PercentDone);
}
virtual uint32_t GetProgressUpdateDelayMS() override { return 2000; }
@@ -59,13 +58,14 @@ public:
private:
LoggerRef m_Log;
std::string m_LogOperationName;
+ LoggerRef Log() { return m_Log; }
};
void
StandardLogOutputProgressBar::UpdateState(const State& NewState, bool DoLinebreak)
{
ZEN_UNUSED(DoLinebreak);
- const size_t PercentDone =
+ [[maybe_unused]] const size_t PercentDone =
NewState.TotalCount > 0u ? gsl::narrow<uint8_t>((100 * (NewState.TotalCount - NewState.RemainingCount)) / NewState.TotalCount) : 0u;
std::string Task = NewState.Task;
switch (NewState.Status)
@@ -95,7 +95,7 @@ StandardLogOutputProgressBar::Finish()
}
OperationLogOutput*
-CreateStandardLogOutput(LoggerRef& Log)
+CreateStandardLogOutput(LoggerRef Log)
{
return new StandardLogOutput(Log);
}
diff --git a/src/zenremotestore/partialblockrequestmode.cpp b/src/zenremotestore/partialblockrequestmode.cpp
new file mode 100644
index 000000000..b3edf515b
--- /dev/null
+++ b/src/zenremotestore/partialblockrequestmode.cpp
@@ -0,0 +1,27 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenremotestore/partialblockrequestmode.h>
+
+#include <zencore/string.h>
+
+namespace zen {
+
+EPartialBlockRequestMode
+PartialBlockRequestModeFromString(const std::string_view ModeString)
+{
+ switch (HashStringAsLowerDjb2(ModeString))
+ {
+ case HashStringDjb2("false"):
+ return EPartialBlockRequestMode::Off;
+ case HashStringDjb2("zencacheonly"):
+ return EPartialBlockRequestMode::ZenCacheOnly;
+ case HashStringDjb2("mixed"):
+ return EPartialBlockRequestMode::Mixed;
+ case HashStringDjb2("true"):
+ return EPartialBlockRequestMode::All;
+ default:
+ return EPartialBlockRequestMode::Invalid;
+ }
+}
+
+} // namespace zen
diff --git a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp
index a8e883dde..2282a31dd 100644
--- a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp
+++ b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp
@@ -7,8 +7,6 @@
#include <zencore/fmtutils.h>
#include <zencore/scopeguard.h>
-#include <zenhttp/httpclientauth.h>
-#include <zenremotestore/builds/buildstoragecache.h>
#include <zenremotestore/builds/buildstorageutil.h>
#include <zenremotestore/builds/jupiterbuildstorage.h>
#include <zenremotestore/operationlogoutput.h>
@@ -26,18 +24,14 @@ class BuildsRemoteStore : public RemoteProjectStore
public:
BuildsRemoteStore(LoggerRef InLog,
const HttpClientSettings& ClientSettings,
- HttpClientSettings* OptionalCacheClientSettings,
std::string_view HostUrl,
- std::string_view CacheUrl,
const std::filesystem::path& TempFilePath,
- WorkerThreadPool& CacheBackgroundWorkerPool,
std::string_view Namespace,
std::string_view Bucket,
const Oid& BuildId,
const IoBuffer& MetaData,
bool ForceDisableBlocks,
- bool ForceDisableTempBlocks,
- bool PopulateCache)
+ bool ForceDisableTempBlocks)
: m_Log(InLog)
, m_BuildStorageHttp(HostUrl, ClientSettings)
, m_BuildStorage(CreateJupiterBuildStorage(Log(),
@@ -53,20 +47,8 @@ public:
, m_MetaData(MetaData)
, m_EnableBlocks(!ForceDisableBlocks)
, m_UseTempBlocks(!ForceDisableTempBlocks)
- , m_PopulateCache(PopulateCache)
{
m_MetaData.MakeOwned();
- if (OptionalCacheClientSettings)
- {
- ZEN_ASSERT(!CacheUrl.empty());
- m_BuildCacheStorageHttp = std::make_unique<HttpClient>(CacheUrl, *OptionalCacheClientSettings);
- m_BuildCacheStorage = CreateZenBuildStorageCache(*m_BuildCacheStorageHttp,
- m_StorageCacheStats,
- Namespace,
- Bucket,
- TempFilePath,
- CacheBackgroundWorkerPool);
- }
}
virtual RemoteStoreInfo GetInfo() const override
@@ -75,9 +57,8 @@ public:
.UseTempBlockFiles = m_UseTempBlocks,
.AllowChunking = true,
.ContainerName = fmt::format("{}/{}/{}", m_Namespace, m_Bucket, m_BuildId),
- .Description = fmt::format("[cloud] {}{}. SessionId: {}. {}/{}/{}"sv,
+ .Description = fmt::format("[cloud] {}. SessionId: {}. {}/{}/{}"sv,
m_BuildStorageHttp.GetBaseUri(),
- m_BuildCacheStorage ? fmt::format(" (Cache: {})", m_BuildCacheStorageHttp->GetBaseUri()) : ""sv,
m_BuildStorageHttp.GetSessionId(),
m_Namespace,
m_Bucket,
@@ -86,15 +67,13 @@ public:
virtual Stats GetStats() const override
{
- return {
- .m_SentBytes = m_BuildStorageStats.TotalBytesWritten.load() + m_StorageCacheStats.TotalBytesWritten.load(),
- .m_ReceivedBytes = m_BuildStorageStats.TotalBytesRead.load() + m_StorageCacheStats.TotalBytesRead.load(),
- .m_RequestTimeNS = m_BuildStorageStats.TotalRequestTimeUs.load() * 1000 + m_StorageCacheStats.TotalRequestTimeUs.load() * 1000,
- .m_RequestCount = m_BuildStorageStats.TotalRequestCount.load() + m_StorageCacheStats.TotalRequestCount.load(),
- .m_PeakSentBytes = Max(m_BuildStorageStats.PeakSentBytes.load(), m_StorageCacheStats.PeakSentBytes.load()),
- .m_PeakReceivedBytes = Max(m_BuildStorageStats.PeakReceivedBytes.load(), m_StorageCacheStats.PeakReceivedBytes.load()),
- .m_PeakBytesPerSec = Max(m_BuildStorageStats.PeakBytesPerSec.load(), m_StorageCacheStats.PeakBytesPerSec.load()),
- };
+ return {.m_SentBytes = m_BuildStorageStats.TotalBytesWritten.load(),
+ .m_ReceivedBytes = m_BuildStorageStats.TotalBytesRead.load(),
+ .m_RequestTimeNS = m_BuildStorageStats.TotalRequestTimeUs.load() * 1000,
+ .m_RequestCount = m_BuildStorageStats.TotalRequestCount.load(),
+ .m_PeakSentBytes = m_BuildStorageStats.PeakSentBytes.load(),
+ .m_PeakReceivedBytes = m_BuildStorageStats.PeakReceivedBytes.load(),
+ .m_PeakBytesPerSec = m_BuildStorageStats.PeakBytesPerSec.load()};
}
virtual bool GetExtendedStats(ExtendedStats& OutStats) const override
@@ -109,11 +88,6 @@ public:
}
Result = true;
}
- if (m_BuildCacheStorage)
- {
- OutStats.m_ReceivedBytesPerSource.insert_or_assign("Cache", m_StorageCacheStats.TotalBytesRead);
- Result = true;
- }
return Result;
}
@@ -441,7 +415,7 @@ public:
catch (const HttpClientError& Ex)
{
Result.ErrorCode = MakeErrorCode(Ex);
- Result.Reason = fmt::format("Failed listing know blocks for {}/{}/{}/{}. Reason: '{}'",
+ Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. Reason: '{}'",
m_BuildStorageHttp.GetBaseUri(),
m_Namespace,
m_Bucket,
@@ -451,7 +425,7 @@ public:
catch (const std::exception& Ex)
{
Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
- Result.Reason = fmt::format("Failed listing know blocks for {}/{}/{}/{}. Reason: '{}'",
+ Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. Reason: '{}'",
m_BuildStorageHttp.GetBaseUri(),
m_Namespace,
m_Bucket,
@@ -462,6 +436,53 @@ public:
return Result;
}
+ virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span<const IoHash> BlockHashes,
+ BuildStorageCache* OptionalCache,
+ const Oid& CacheBuildId) override
+ {
+ std::unique_ptr<OperationLogOutput> Output(CreateStandardLogOutput(Log()));
+
+ ZEN_ASSERT(m_OplogBuildPartId != Oid::Zero);
+ ZEN_ASSERT(OptionalCache == nullptr || CacheBuildId == m_BuildId);
+
+ GetBlockDescriptionsResult Result;
+ Stopwatch Timer;
+ auto _ = MakeGuard([&Timer, &Result]() { Result.ElapsedSeconds = Timer.GetElapsedTimeUs() / 1000000.0; });
+
+ try
+ {
+ Result.Blocks = zen::GetBlockDescriptions(*Output,
+ *m_BuildStorage,
+ OptionalCache,
+ m_BuildId,
+ BlockHashes,
+ /*AttemptFallback*/ false,
+ /*IsQuiet*/ false,
+ /*IsVerbose*/ false);
+ }
+ catch (const HttpClientError& Ex)
+ {
+ Result.ErrorCode = MakeErrorCode(Ex);
+ Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. Reason: '{}'",
+ m_BuildStorageHttp.GetBaseUri(),
+ m_Namespace,
+ m_Bucket,
+ m_BuildId,
+ Ex.what());
+ }
+ catch (const std::exception& Ex)
+ {
+ Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+ Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. Reason: '{}'",
+ m_BuildStorageHttp.GetBaseUri(),
+ m_Namespace,
+ m_Bucket,
+ m_BuildId,
+ Ex.what());
+ }
+ return Result;
+ }
+
virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override
{
ZEN_ASSERT(m_OplogBuildPartId != Oid::Zero);
@@ -472,44 +493,73 @@ public:
try
{
- if (m_BuildCacheStorage)
- {
- IoBuffer CachedBlob = m_BuildCacheStorage->GetBuildBlob(m_BuildId, RawHash);
- if (CachedBlob)
- {
- Result.Bytes = std::move(CachedBlob);
- }
- }
- if (!Result.Bytes)
+ Result.Bytes = m_BuildStorage->GetBuildBlob(m_BuildId, RawHash);
+ }
+ catch (const HttpClientError& Ex)
+ {
+ Result.ErrorCode = MakeErrorCode(Ex);
+ Result.Reason = fmt::format("Failed getting blob {}/{}/{}/{}/{}. Reason: '{}'",
+ m_BuildStorageHttp.GetBaseUri(),
+ m_Namespace,
+ m_Bucket,
+ m_BuildId,
+ RawHash,
+ Ex.what());
+ }
+ catch (const std::exception& Ex)
+ {
+ Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+ Result.Reason = fmt::format("Failed getting blob {}/{}/{}/{}/{}. Reason: '{}'",
+ m_BuildStorageHttp.GetBaseUri(),
+ m_Namespace,
+ m_Bucket,
+ m_BuildId,
+ RawHash,
+ Ex.what());
+ }
+
+ return Result;
+ }
+
+ virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash,
+ std::span<const std::pair<uint64_t, uint64_t>> Ranges) override
+ {
+ ZEN_ASSERT(!Ranges.empty());
+ LoadAttachmentRangesResult Result;
+ Stopwatch Timer;
+ auto _ = MakeGuard([&Timer, &Result]() { Result.ElapsedSeconds = Timer.GetElapsedTimeUs() / 1000000.0; });
+
+ try
+ {
+ BuildStorageBase::BuildBlobRanges BlobRanges = m_BuildStorage->GetBuildBlobRanges(m_BuildId, RawHash, Ranges);
+ if (BlobRanges.PayloadBuffer)
{
- Result.Bytes = m_BuildStorage->GetBuildBlob(m_BuildId, RawHash);
- if (m_BuildCacheStorage && Result.Bytes && m_PopulateCache)
- {
- m_BuildCacheStorage->PutBuildBlob(m_BuildId,
- RawHash,
- Result.Bytes.GetContentType(),
- CompositeBuffer(SharedBuffer(Result.Bytes)));
- }
+ Result.Bytes = std::move(BlobRanges.PayloadBuffer);
+ Result.Ranges = std::move(BlobRanges.Ranges);
}
}
catch (const HttpClientError& Ex)
{
Result.ErrorCode = MakeErrorCode(Ex);
- Result.Reason = fmt::format("Failed listing know blocks for {}/{}/{}/{}. Reason: '{}'",
+ Result.Reason = fmt::format("Failed getting {} ranges for blob {}/{}/{}/{}/{}. Reason: '{}'",
+ Ranges.size(),
m_BuildStorageHttp.GetBaseUri(),
m_Namespace,
m_Bucket,
m_BuildId,
+ RawHash,
Ex.what());
}
catch (const std::exception& Ex)
{
Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
- Result.Reason = fmt::format("Failed listing know blocks for {}/{}/{}/{}. Reason: '{}'",
+ Result.Reason = fmt::format("Failed getting {} ranges for blob {}/{}/{}/{}/{}. Reason: '{}'",
+ Ranges.size(),
m_BuildStorageHttp.GetBaseUri(),
m_Namespace,
m_Bucket,
m_BuildId,
+ RawHash,
Ex.what());
}
@@ -524,38 +574,6 @@ public:
std::vector<IoHash> AttachmentsLeftToFind = RawHashes;
- if (m_BuildCacheStorage)
- {
- std::vector<BuildStorageCache::BlobExistsResult> ExistCheck = m_BuildCacheStorage->BlobsExists(m_BuildId, RawHashes);
- if (ExistCheck.size() == RawHashes.size())
- {
- AttachmentsLeftToFind.clear();
- for (size_t BlobIndex = 0; BlobIndex < RawHashes.size(); BlobIndex++)
- {
- const IoHash& Hash = RawHashes[BlobIndex];
- const BuildStorageCache::BlobExistsResult& BlobExists = ExistCheck[BlobIndex];
- if (BlobExists.HasBody)
- {
- IoBuffer CachedPayload = m_BuildCacheStorage->GetBuildBlob(m_BuildId, Hash);
- if (CachedPayload)
- {
- Result.Chunks.emplace_back(
- std::pair<IoHash, CompressedBuffer>{Hash,
- CompressedBuffer::FromCompressedNoValidate(std::move(CachedPayload))});
- }
- else
- {
- AttachmentsLeftToFind.push_back(Hash);
- }
- }
- else
- {
- AttachmentsLeftToFind.push_back(Hash);
- }
- }
- }
- }
-
for (const IoHash& Hash : AttachmentsLeftToFind)
{
LoadAttachmentResult ChunkResult = LoadAttachment(Hash);
@@ -564,27 +582,12 @@ public:
return LoadAttachmentsResult{ChunkResult};
}
ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(ChunkResult.ElapsedSeconds * 1000)));
- if (m_BuildCacheStorage && ChunkResult.Bytes && m_PopulateCache)
- {
- m_BuildCacheStorage->PutBuildBlob(m_BuildId,
- Hash,
- ChunkResult.Bytes.GetContentType(),
- CompositeBuffer(SharedBuffer(ChunkResult.Bytes)));
- }
Result.Chunks.emplace_back(
std::pair<IoHash, CompressedBuffer>{Hash, CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult.Bytes))});
}
return Result;
}
- virtual void Flush() override
- {
- if (m_BuildCacheStorage)
- {
- m_BuildCacheStorage->Flush(100, [](intptr_t) { return false; });
- }
- }
-
private:
static int MakeErrorCode(const HttpClientError& Ex)
{
@@ -601,10 +604,6 @@ private:
HttpClient m_BuildStorageHttp;
std::unique_ptr<BuildStorageBase> m_BuildStorage;
- BuildStorageCache::Statistics m_StorageCacheStats;
- std::unique_ptr<HttpClient> m_BuildCacheStorageHttp;
- std::unique_ptr<BuildStorageCache> m_BuildCacheStorage;
-
const std::string m_Namespace;
const std::string m_Bucket;
const Oid m_BuildId;
@@ -613,120 +612,35 @@ private:
const bool m_EnableBlocks = true;
const bool m_UseTempBlocks = true;
const bool m_AllowRedirect = false;
- const bool m_PopulateCache = true;
};
std::shared_ptr<RemoteProjectStore>
-CreateJupiterBuildsRemoteStore(LoggerRef InLog,
- const BuildsRemoteStoreOptions& Options,
- const std::filesystem::path& TempFilePath,
- bool Quiet,
- bool Unattended,
- bool Hidden,
- WorkerThreadPool& CacheBackgroundWorkerPool)
+CreateJupiterBuildsRemoteStore(LoggerRef InLog,
+ const BuildStorageResolveResult& ResolveResult,
+ std::function<HttpClientAccessToken()>&& TokenProvider,
+ const BuildsRemoteStoreOptions& Options,
+ const std::filesystem::path& TempFilePath)
{
- std::string Host = Options.Host;
- if (!Host.empty() && Host.find("://"sv) == std::string::npos)
- {
- // Assume https URL
- Host = fmt::format("https://{}"sv, Host);
- }
- std::string OverrideUrl = Options.OverrideHost;
- if (!OverrideUrl.empty() && OverrideUrl.find("://"sv) == std::string::npos)
- {
- // Assume https URL
- OverrideUrl = fmt::format("https://{}"sv, OverrideUrl);
- }
- std::string ZenHost = Options.ZenHost;
- if (!ZenHost.empty() && ZenHost.find("://"sv) == std::string::npos)
- {
- // Assume https URL
- ZenHost = fmt::format("https://{}"sv, ZenHost);
- }
-
- // 1) openid-provider if given (assumes oidctoken.exe -Zen true has been run with matching Options.OpenIdProvider
- // 2) Access token as parameter in request
- // 3) Environment variable (different win vs linux/mac)
- // 4) Default openid-provider (assumes oidctoken.exe -Zen true has been run with matching Options.OpenIdProvider
-
- std::function<HttpClientAccessToken()> TokenProvider;
- if (!Options.OpenIdProvider.empty())
- {
- TokenProvider = httpclientauth::CreateFromOpenIdProvider(Options.AuthManager, Options.OpenIdProvider);
- }
- else if (!Options.AccessToken.empty())
- {
- TokenProvider = httpclientauth::CreateFromStaticToken(Options.AccessToken);
- }
- else if (!Options.OidcExePath.empty())
- {
- if (auto TokenProviderMaybe = httpclientauth::CreateFromOidcTokenExecutable(Options.OidcExePath,
- Host.empty() ? OverrideUrl : Host,
- Quiet,
- Unattended,
- Hidden);
- TokenProviderMaybe)
- {
- TokenProvider = TokenProviderMaybe.value();
- }
- }
-
- if (!TokenProvider)
- {
- TokenProvider = httpclientauth::CreateFromDefaultOpenIdProvider(Options.AuthManager);
- }
-
- BuildStorageResolveResult ResolveRes;
- {
- HttpClientSettings ClientSettings{.LogCategory = "httpbuildsclient",
- .AccessTokenProvider = TokenProvider,
- .AssumeHttp2 = Options.AssumeHttp2,
- .AllowResume = true,
- .RetryCount = 2};
-
- std::unique_ptr<OperationLogOutput> Output(CreateStandardLogOutput(InLog));
-
- ResolveRes =
- ResolveBuildStorage(*Output, ClientSettings, Host, OverrideUrl, ZenHost, ZenCacheResolveMode::Discovery, /*Verbose*/ false);
- }
-
HttpClientSettings ClientSettings{.LogCategory = "httpbuildsclient",
.ConnectTimeout = std::chrono::milliseconds(3000),
.Timeout = std::chrono::milliseconds(1800000),
.AccessTokenProvider = std::move(TokenProvider),
- .AssumeHttp2 = ResolveRes.HostAssumeHttp2,
+ .AssumeHttp2 = ResolveResult.Cloud.AssumeHttp2,
.AllowResume = true,
.RetryCount = 4,
.MaximumInMemoryDownloadSize = Options.MaximumInMemoryDownloadSize};
- std::unique_ptr<HttpClientSettings> CacheClientSettings;
-
- if (!ResolveRes.CacheUrl.empty())
- {
- CacheClientSettings =
- std::make_unique<HttpClientSettings>(HttpClientSettings{.LogCategory = "httpcacheclient",
- .ConnectTimeout = std::chrono::milliseconds{3000},
- .Timeout = std::chrono::milliseconds{30000},
- .AssumeHttp2 = ResolveRes.CacheAssumeHttp2,
- .AllowResume = true,
- .RetryCount = 0,
- .MaximumInMemoryDownloadSize = Options.MaximumInMemoryDownloadSize});
- }
-
std::shared_ptr<RemoteProjectStore> RemoteStore = std::make_shared<BuildsRemoteStore>(InLog,
ClientSettings,
- CacheClientSettings.get(),
- ResolveRes.HostUrl,
- ResolveRes.CacheUrl,
+ ResolveResult.Cloud.Address,
TempFilePath,
- CacheBackgroundWorkerPool,
Options.Namespace,
Options.Bucket,
Options.BuildId,
Options.MetaData,
Options.ForceDisableBlocks,
- Options.ForceDisableTempBlocks,
- Options.PopulateCache);
+ Options.ForceDisableTempBlocks);
+
return RemoteStore;
}
diff --git a/src/zenremotestore/projectstore/fileremoteprojectstore.cpp b/src/zenremotestore/projectstore/fileremoteprojectstore.cpp
index 3a67d3842..bb21de12c 100644
--- a/src/zenremotestore/projectstore/fileremoteprojectstore.cpp
+++ b/src/zenremotestore/projectstore/fileremoteprojectstore.cpp
@@ -7,8 +7,12 @@
#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
#include <zencore/timer.h>
#include <zenhttp/httpcommon.h>
+#include <zenremotestore/builds/buildstoragecache.h>
+
+#include <numeric>
namespace zen {
@@ -74,9 +78,11 @@ public:
virtual SaveResult SaveContainer(const IoBuffer& Payload) override
{
- Stopwatch Timer;
SaveResult Result;
+ Stopwatch Timer;
+ auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; });
+
{
CbObject ContainerObject = LoadCompactBinaryObject(Payload);
@@ -87,6 +93,10 @@ public:
{
Result.Needs.insert(AttachmentHash);
}
+ else if (std::filesystem::path AttachmentMetaPath = GetAttachmentMetaPath(AttachmentHash); IsFile(AttachmentMetaPath))
+ {
+ BasicFile TouchIt(AttachmentMetaPath, BasicFile::Mode::kWrite);
+ }
});
}
@@ -112,14 +122,18 @@ public:
Result.Reason = fmt::format("Failed saving oplog container to '{}'. Reason: {}", ContainerPath, Ex.what());
}
AddStats(Payload.GetSize(), 0, Timer.GetElapsedTimeUs() * 1000);
- Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
return Result;
}
- virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash, ChunkBlockDescription&&) override
+ virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload,
+ const IoHash& RawHash,
+ ChunkBlockDescription&& BlockDescription) override
{
- Stopwatch Timer;
- SaveAttachmentResult Result;
+ SaveAttachmentResult Result;
+
+ Stopwatch Timer;
+ auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; });
+
std::filesystem::path ChunkPath = GetAttachmentPath(RawHash);
if (!IsFile(ChunkPath))
{
@@ -142,14 +156,33 @@ public:
Result.Reason = fmt::format("Failed saving oplog attachment to '{}'. Reason: {}", ChunkPath, Ex.what());
}
}
+ if (!Result.ErrorCode && BlockDescription.BlockHash != IoHash::Zero)
+ {
+ try
+ {
+ std::filesystem::path MetaPath = GetAttachmentMetaPath(RawHash);
+ CbObject MetaData = BuildChunkBlockDescription(BlockDescription, {});
+ SharedBuffer MetaBuffer = MetaData.GetBuffer();
+ BasicFile MetaFile;
+ MetaFile.Open(MetaPath, BasicFile::Mode::kTruncate);
+ MetaFile.Write(MetaBuffer.GetView(), 0);
+ }
+ catch (const std::exception& Ex)
+ {
+ Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+ Result.Reason = fmt::format("Failed saving block description to '{}'. Reason: {}", RawHash, Ex.what());
+ }
+ }
AddStats(Payload.GetSize(), 0, Timer.GetElapsedTimeUs() * 1000);
- Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
return Result;
}
virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override
{
+ SaveAttachmentsResult Result;
+
Stopwatch Timer;
+ auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; });
for (const SharedBuffer& Chunk : Chunks)
{
@@ -157,12 +190,10 @@ public:
SaveAttachmentResult ChunkResult = SaveAttachment(Compressed.GetCompressed(), Compressed.DecodeRawHash(), {});
if (ChunkResult.ErrorCode)
{
- ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
- return SaveAttachmentsResult{ChunkResult};
+ Result = SaveAttachmentsResult{ChunkResult};
+ break;
}
}
- SaveAttachmentsResult Result;
- Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
return Result;
}
@@ -172,21 +203,60 @@ public:
virtual GetKnownBlocksResult GetKnownBlocks() override
{
+ Stopwatch Timer;
if (m_OptionalBaseName.empty())
{
- return GetKnownBlocksResult{{.ErrorCode = static_cast<int>(HttpResponseCode::NoContent)}};
+ size_t MaxBlockCount = 10000;
+
+ GetKnownBlocksResult Result;
+
+ DirectoryContent Content;
+ GetDirectoryContent(
+ m_OutputPath,
+ DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::Recursive | DirectoryContentFlags::IncludeModificationTick,
+ Content);
+ std::vector<size_t> RecentOrder(Content.Files.size());
+ std::iota(RecentOrder.begin(), RecentOrder.end(), 0u);
+ std::sort(RecentOrder.begin(), RecentOrder.end(), [&Content](size_t Lhs, size_t Rhs) {
+ return Content.FileModificationTicks[Lhs] > Content.FileModificationTicks[Rhs];
+ });
+
+ for (size_t FileIndex : RecentOrder)
+ {
+ std::filesystem::path MetaPath = Content.Files[FileIndex];
+ if (MetaPath.extension() == MetaExtension)
+ {
+ IoBuffer MetaFile = ReadFile(MetaPath).Flatten();
+ CbValidateError Err;
+ CbObject ValidatedObject = ValidateAndReadCompactBinaryObject(std::move(MetaFile), Err);
+ if (Err == CbValidateError::None)
+ {
+ ChunkBlockDescription Description = ParseChunkBlockDescription(ValidatedObject);
+ if (Description.BlockHash != IoHash::Zero)
+ {
+ Result.Blocks.emplace_back(std::move(Description));
+ if (Result.Blocks.size() == MaxBlockCount)
+ {
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ return Result;
}
LoadContainerResult LoadResult = LoadContainer(m_OptionalBaseName);
if (LoadResult.ErrorCode)
{
return GetKnownBlocksResult{LoadResult};
}
- Stopwatch Timer;
std::vector<IoHash> BlockHashes = GetBlockHashesFromOplog(LoadResult.ContainerObject);
if (BlockHashes.empty())
{
return GetKnownBlocksResult{{.ErrorCode = static_cast<int>(HttpResponseCode::NoContent),
- .ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeUs() * 1000}};
+ .ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeMs() / 1000.0}};
}
std::vector<IoHash> ExistingBlockHashes;
for (const IoHash& RawHash : BlockHashes)
@@ -200,15 +270,15 @@ public:
if (ExistingBlockHashes.empty())
{
return GetKnownBlocksResult{{.ErrorCode = static_cast<int>(HttpResponseCode::NoContent),
- .ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeUs() * 1000}};
+ .ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeMs() / 1000.0}};
}
std::vector<ThinChunkBlockDescription> ThinKnownBlocks = GetBlocksFromOplog(LoadResult.ContainerObject, ExistingBlockHashes);
- const size_t KnowBlockCount = ThinKnownBlocks.size();
+ const size_t KnownBlockCount = ThinKnownBlocks.size();
- GetKnownBlocksResult Result{{.ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeUs() * 1000}};
- Result.Blocks.resize(KnowBlockCount);
- for (size_t BlockIndex = 0; BlockIndex < KnowBlockCount; BlockIndex++)
+ GetKnownBlocksResult Result{{.ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeMs() / 1000.0}};
+ Result.Blocks.resize(KnownBlockCount);
+ for (size_t BlockIndex = 0; BlockIndex < KnownBlockCount; BlockIndex++)
{
Result.Blocks[BlockIndex].BlockHash = ThinKnownBlocks[BlockIndex].BlockHash;
Result.Blocks[BlockIndex].ChunkRawHashes = std::move(ThinKnownBlocks[BlockIndex].ChunkRawHashes);
@@ -217,16 +287,88 @@ public:
return Result;
}
+ virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span<const IoHash> BlockHashes,
+ BuildStorageCache* OptionalCache,
+ const Oid& CacheBuildId) override
+ {
+ GetBlockDescriptionsResult Result;
+
+ Stopwatch Timer;
+ auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; });
+
+ Result.Blocks.reserve(BlockHashes.size());
+
+ uint64_t ByteCount = 0;
+
+ std::vector<ChunkBlockDescription> UnorderedList;
+ {
+ if (OptionalCache)
+ {
+ std::vector<CbObject> CacheBlockMetadatas = OptionalCache->GetBlobMetadatas(CacheBuildId, BlockHashes);
+ for (const CbObject& BlockObject : CacheBlockMetadatas)
+ {
+ ByteCount += BlockObject.GetSize();
+ }
+ UnorderedList = ParseBlockMetadatas(CacheBlockMetadatas);
+ }
+
+ tsl::robin_map<IoHash, size_t, IoHash::Hasher> BlockDescriptionLookup;
+ BlockDescriptionLookup.reserve(BlockHashes.size());
+ for (size_t DescriptionIndex = 0; DescriptionIndex < UnorderedList.size(); DescriptionIndex++)
+ {
+ const ChunkBlockDescription& Description = UnorderedList[DescriptionIndex];
+ BlockDescriptionLookup.insert_or_assign(Description.BlockHash, DescriptionIndex);
+ }
+
+ if (UnorderedList.size() < BlockHashes.size())
+ {
+ for (const IoHash& RawHash : BlockHashes)
+ {
+ if (!BlockDescriptionLookup.contains(RawHash))
+ {
+ std::filesystem::path MetaPath = GetAttachmentMetaPath(RawHash);
+ IoBuffer MetaFile = ReadFile(MetaPath).Flatten();
+ ByteCount += MetaFile.GetSize();
+ CbValidateError Err;
+ CbObject ValidatedObject = ValidateAndReadCompactBinaryObject(std::move(MetaFile), Err);
+ if (Err == CbValidateError::None)
+ {
+ ChunkBlockDescription Description = ParseChunkBlockDescription(ValidatedObject);
+ if (Description.BlockHash != IoHash::Zero)
+ {
+ BlockDescriptionLookup.insert_or_assign(Description.BlockHash, UnorderedList.size());
+ UnorderedList.emplace_back(std::move(Description));
+ }
+ }
+ }
+ }
+ }
+
+ Result.Blocks.reserve(UnorderedList.size());
+ for (const IoHash& RawHash : BlockHashes)
+ {
+ if (auto It = BlockDescriptionLookup.find(RawHash); It != BlockDescriptionLookup.end())
+ {
+ Result.Blocks.emplace_back(std::move(UnorderedList[It->second]));
+ }
+ }
+ }
+ AddStats(0, ByteCount, Timer.GetElapsedTimeUs() * 1000);
+ return Result;
+ }
+
virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override
{
- Stopwatch Timer;
- LoadAttachmentResult Result;
+ LoadAttachmentResult Result;
+
+ Stopwatch Timer;
+ auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; });
+
std::filesystem::path ChunkPath = GetAttachmentPath(RawHash);
if (!IsFile(ChunkPath))
{
Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound);
Result.Reason = fmt::format("Failed loading oplog attachment from '{}'. Reason: 'The file does not exist'", ChunkPath.string());
- Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
return Result;
}
{
@@ -235,7 +377,41 @@ public:
Result.Bytes = ChunkFile.ReadAll();
}
AddStats(0, Result.Bytes.GetSize(), Timer.GetElapsedTimeUs() * 1000);
- Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ return Result;
+ }
+
+ virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash,
+ std::span<const std::pair<uint64_t, uint64_t>> Ranges) override
+ {
+ ZEN_ASSERT(!Ranges.empty());
+ LoadAttachmentRangesResult Result;
+
+ Stopwatch Timer;
+ auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; });
+
+ std::filesystem::path ChunkPath = GetAttachmentPath(RawHash);
+ if (!IsFile(ChunkPath))
+ {
+ Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound);
+ Result.Reason = fmt::format("Failed loading oplog attachment from '{}'. Reason: 'The file does not exist'", ChunkPath.string());
+ return Result;
+ }
+ {
+ uint64_t Start = Ranges.front().first;
+ uint64_t Length = Ranges.back().first + Ranges.back().second - Ranges.front().first;
+ Result.Bytes = IoBufferBuilder::MakeFromFile(ChunkPath, Start, Length);
+ Result.Ranges.reserve(Ranges.size());
+ for (const std::pair<uint64_t, uint64_t>& Range : Ranges)
+ {
+ Result.Ranges.push_back(std::make_pair(Range.first - Start, Range.second));
+ }
+ }
+ AddStats(0,
+ std::accumulate(Result.Ranges.begin(),
+ Result.Ranges.end(),
+ uint64_t(0),
+ [](uint64_t Current, const std::pair<uint64_t, uint64_t>& Value) { return Current + Value.second; }),
+ Timer.GetElapsedTimeUs() * 1000);
return Result;
}
@@ -258,20 +434,20 @@ public:
return Result;
}
- virtual void Flush() override {}
-
private:
LoadContainerResult LoadContainer(const std::string& Name)
{
- Stopwatch Timer;
- LoadContainerResult Result;
+ LoadContainerResult Result;
+
+ Stopwatch Timer;
+ auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; });
+
std::filesystem::path SourcePath = m_OutputPath;
SourcePath.append(Name);
if (!IsFile(SourcePath))
{
Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound);
Result.Reason = fmt::format("Failed loading oplog container from '{}'. Reason: 'The file does not exist'", SourcePath.string());
- Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
return Result;
}
IoBuffer ContainerPayload;
@@ -285,18 +461,16 @@ private:
if (Result.ContainerObject = ValidateAndReadCompactBinaryObject(std::move(ContainerPayload), ValidateResult);
ValidateResult != CbValidateError::None || !Result.ContainerObject)
{
- Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
- Result.Reason = fmt::format("The file {} is not formatted as a compact binary object ('{}')",
- SourcePath.string(),
- ToString(ValidateResult));
- Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+ Result.Reason = fmt::format("The file {} is not formatted as a compact binary object ('{}')",
+ SourcePath.string(),
+ ToString(ValidateResult));
return Result;
}
- Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
return Result;
}
- std::filesystem::path GetAttachmentPath(const IoHash& RawHash) const
+ std::filesystem::path GetAttachmentBasePath(const IoHash& RawHash) const
{
ExtendablePathBuilder<128> ShardedPath;
ShardedPath.Append(m_OutputPath.c_str());
@@ -315,6 +489,19 @@ private:
return ShardedPath.ToPath();
}
+ static constexpr std::string_view BlobExtension = ".blob";
+ static constexpr std::string_view MetaExtension = ".meta";
+
+ std::filesystem::path GetAttachmentPath(const IoHash& RawHash)
+ {
+ return GetAttachmentBasePath(RawHash).replace_extension(BlobExtension);
+ }
+
+ std::filesystem::path GetAttachmentMetaPath(const IoHash& RawHash)
+ {
+ return GetAttachmentBasePath(RawHash).replace_extension(MetaExtension);
+ }
+
void AddStats(uint64_t UploadedBytes, uint64_t DownloadedBytes, uint64_t ElapsedNS)
{
m_SentBytes.fetch_add(UploadedBytes);
diff --git a/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp b/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp
index 462de2988..5b456cb4c 100644
--- a/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp
+++ b/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp
@@ -212,13 +212,43 @@ public:
return Result;
}
+ virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span<const IoHash> BlockHashes,
+ BuildStorageCache* OptionalCache,
+ const Oid& CacheBuildId) override
+ {
+ ZEN_UNUSED(BlockHashes, OptionalCache, CacheBuildId);
+ return GetBlockDescriptionsResult{Result{.ErrorCode = int(HttpResponseCode::NotFound)}};
+ }
+
virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override
{
- JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect);
- JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath);
+ LoadAttachmentResult Result;
+ JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect);
+ JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath);
+ AddStats(GetResult);
+
+ Result = {ConvertResult(GetResult), std::move(GetResult.Response)};
+ if (GetResult.ErrorCode)
+ {
+ Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}. Reason: '{}'",
+ m_JupiterClient->ServiceUrl(),
+ m_Namespace,
+ RawHash,
+ Result.Reason);
+ }
+ return Result;
+ }
+
+ virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash,
+ std::span<const std::pair<uint64_t, uint64_t>> Ranges) override
+ {
+ ZEN_ASSERT(!Ranges.empty());
+ LoadAttachmentRangesResult Result;
+ JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect);
+ JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath);
AddStats(GetResult);
- LoadAttachmentResult Result{ConvertResult(GetResult), std::move(GetResult.Response)};
+ Result = LoadAttachmentRangesResult{ConvertResult(GetResult), std::move(GetResult.Response)};
if (GetResult.ErrorCode)
{
Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}. Reason: '{}'",
@@ -227,6 +257,10 @@ public:
RawHash,
Result.Reason);
}
+ else
+ {
+ Result.Ranges = std::vector<std::pair<uint64_t, uint64_t>>(Ranges.begin(), Ranges.end());
+ }
return Result;
}
@@ -247,8 +281,6 @@ public:
return Result;
}
- virtual void Flush() override {}
-
private:
LoadContainerResult LoadContainer(const IoHash& Key)
{
diff --git a/src/zenremotestore/projectstore/projectstoreoperations.cpp b/src/zenremotestore/projectstore/projectstoreoperations.cpp
index becac3d4c..36dc4d868 100644
--- a/src/zenremotestore/projectstore/projectstoreoperations.cpp
+++ b/src/zenremotestore/projectstore/projectstoreoperations.cpp
@@ -426,19 +426,19 @@ ProjectStoreOperationDownloadAttachments::Execute()
auto GetBuildBlob = [this](const IoHash& RawHash, const std::filesystem::path& OutputPath) {
IoBuffer Payload;
- if (m_Storage.BuildCacheStorage)
+ if (m_Storage.CacheStorage)
{
- Payload = m_Storage.BuildCacheStorage->GetBuildBlob(m_State.GetBuildId(), RawHash);
+ Payload = m_Storage.CacheStorage->GetBuildBlob(m_State.GetBuildId(), RawHash);
}
if (!Payload)
{
Payload = m_Storage.BuildStorage->GetBuildBlob(m_State.GetBuildId(), RawHash);
- if (m_Storage.BuildCacheStorage && m_Options.PopulateCache)
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
{
- m_Storage.BuildCacheStorage->PutBuildBlob(m_State.GetBuildId(),
- RawHash,
- Payload.GetContentType(),
- CompositeBuffer(SharedBuffer(Payload)));
+ m_Storage.CacheStorage->PutBuildBlob(m_State.GetBuildId(),
+ RawHash,
+ Payload.GetContentType(),
+ CompositeBuffer(SharedBuffer(Payload)));
}
}
uint64_t PayloadSize = Payload.GetSize();
diff --git a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp
index 8be8eb0df..247bd6cb9 100644
--- a/src/zenremotestore/projectstore/remoteprojectstore.cpp
+++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp
@@ -14,6 +14,8 @@
#include <zencore/trace.h>
#include <zencore/workthreadpool.h>
#include <zenhttp/httpcommon.h>
+#include <zenremotestore/builds/buildstoragecache.h>
+#include <zenremotestore/chunking/chunkedcontent.h>
#include <zenremotestore/chunking/chunkedfile.h>
#include <zenremotestore/operationlogoutput.h>
#include <zenstore/cidstore.h>
@@ -123,14 +125,17 @@ namespace remotestore_impl {
return OptionalContext->IsCancelled();
}
- std::string GetStats(const RemoteProjectStore::Stats& Stats, uint64_t ElapsedWallTimeMS)
+ std::string GetStats(const RemoteProjectStore::Stats& Stats,
+ const BuildStorageCache::Statistics* OptionalCacheStats,
+ uint64_t ElapsedWallTimeMS)
{
- return fmt::format(
- "Sent: {} ({}bits/s) Recv: {} ({}bits/s)",
- NiceBytes(Stats.m_SentBytes),
- NiceNum(ElapsedWallTimeMS > 0u ? static_cast<uint64_t>((Stats.m_SentBytes * 8 * 1000) / ElapsedWallTimeMS) : 0u),
- NiceBytes(Stats.m_ReceivedBytes),
- NiceNum(ElapsedWallTimeMS > 0u ? static_cast<uint64_t>((Stats.m_ReceivedBytes * 8 * 1000) / ElapsedWallTimeMS) : 0u));
+ uint64_t SentBytes = Stats.m_SentBytes + (OptionalCacheStats ? OptionalCacheStats->TotalBytesWritten.load() : 0);
+ uint64_t ReceivedBytes = Stats.m_ReceivedBytes + (OptionalCacheStats ? OptionalCacheStats->TotalBytesRead.load() : 0);
+ return fmt::format("Sent: {} ({}bits/s) Recv: {} ({}bits/s)",
+ NiceBytes(SentBytes),
+ NiceNum(ElapsedWallTimeMS > 0u ? static_cast<uint64_t>((SentBytes * 8 * 1000) / ElapsedWallTimeMS) : 0u),
+ NiceBytes(ReceivedBytes),
+ NiceNum(ElapsedWallTimeMS > 0u ? static_cast<uint64_t>((ReceivedBytes * 8 * 1000) / ElapsedWallTimeMS) : 0u));
}
void LogRemoteStoreStatsDetails(const RemoteProjectStore::Stats& Stats)
@@ -229,44 +234,66 @@ namespace remotestore_impl {
struct DownloadInfo
{
- uint64_t OplogSizeBytes = 0;
- std::atomic<uint64_t> AttachmentsDownloaded = 0;
- std::atomic<uint64_t> AttachmentBlocksDownloaded = 0;
- std::atomic<uint64_t> AttachmentBytesDownloaded = 0;
- std::atomic<uint64_t> AttachmentBlockBytesDownloaded = 0;
- std::atomic<uint64_t> AttachmentsStored = 0;
- std::atomic<uint64_t> AttachmentBytesStored = 0;
- std::atomic_size_t MissingAttachmentCount = 0;
+ uint64_t OplogSizeBytes = 0;
+ std::atomic<uint64_t> AttachmentsDownloaded = 0;
+ std::atomic<uint64_t> AttachmentBlocksDownloaded = 0;
+ std::atomic<uint64_t> AttachmentBlocksRangesDownloaded = 0;
+ std::atomic<uint64_t> AttachmentBytesDownloaded = 0;
+ std::atomic<uint64_t> AttachmentBlockBytesDownloaded = 0;
+ std::atomic<uint64_t> AttachmentBlockRangeBytesDownloaded = 0;
+ std::atomic<uint64_t> AttachmentsStored = 0;
+ std::atomic<uint64_t> AttachmentBytesStored = 0;
+ std::atomic_size_t MissingAttachmentCount = 0;
};
- void DownloadAndSaveBlockChunks(CidStore& ChunkStore,
- RemoteProjectStore& RemoteStore,
- bool IgnoreMissingAttachments,
- JobContext* OptionalContext,
- WorkerThreadPool& NetworkWorkerPool,
- WorkerThreadPool& WorkerPool,
- Latch& AttachmentsDownloadLatch,
- Latch& AttachmentsWriteLatch,
- AsyncRemoteResult& RemoteResult,
- DownloadInfo& Info,
- Stopwatch& LoadAttachmentsTimer,
- std::atomic_uint64_t& DownloadStartMS,
- const std::vector<IoHash>& Chunks)
+ class JobContextLogOutput : public OperationLogOutput
+ {
+ public:
+ JobContextLogOutput(JobContext* OptionalContext) : m_OptionalContext(OptionalContext) {}
+ virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override
+ {
+ if (m_OptionalContext)
+ {
+ fmt::basic_memory_buffer<char, 250> MessageBuffer;
+ fmt::vformat_to(fmt::appender(MessageBuffer), Point.FormatString, Args);
+ remotestore_impl::ReportMessage(m_OptionalContext, std::string_view(MessageBuffer.data(), MessageBuffer.size()));
+ }
+ }
+
+ virtual void SetLogOperationName(std::string_view Name) override { ZEN_UNUSED(Name); }
+ virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { ZEN_UNUSED(StepIndex, StepCount); }
+ virtual uint32_t GetProgressUpdateDelayMS() override { return 0; }
+ virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override
+ {
+ ZEN_UNUSED(InSubTask);
+ return nullptr;
+ }
+
+ private:
+ JobContext* m_OptionalContext;
+ };
+
+ void DownloadAndSaveBlockChunks(LoadOplogContext& Context,
+ Latch& AttachmentsDownloadLatch,
+ Latch& AttachmentsWriteLatch,
+ AsyncRemoteResult& RemoteResult,
+ DownloadInfo& Info,
+ Stopwatch& LoadAttachmentsTimer,
+ std::atomic_uint64_t& DownloadStartMS,
+ ThinChunkBlockDescription&& ThinBlockDescription,
+ std::vector<uint32_t>&& NeededChunkIndexes)
{
AttachmentsDownloadLatch.AddCount(1);
- NetworkWorkerPool.ScheduleWork(
- [&RemoteStore,
- &ChunkStore,
- &WorkerPool,
+ Context.NetworkWorkerPool.ScheduleWork(
+ [&Context,
&AttachmentsDownloadLatch,
&AttachmentsWriteLatch,
&RemoteResult,
- Chunks = Chunks,
+ ThinBlockDescription = std::move(ThinBlockDescription),
+ NeededChunkIndexes = std::move(NeededChunkIndexes),
&Info,
&LoadAttachmentsTimer,
- &DownloadStartMS,
- IgnoreMissingAttachments,
- OptionalContext]() {
+ &DownloadStartMS]() {
ZEN_TRACE_CPU("DownloadBlockChunks");
auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); });
@@ -276,34 +303,47 @@ namespace remotestore_impl {
}
try
{
+ std::vector<IoHash> Chunks;
+ Chunks.reserve(NeededChunkIndexes.size());
+ for (uint32_t ChunkIndex : NeededChunkIndexes)
+ {
+ Chunks.push_back(ThinBlockDescription.ChunkRawHashes[ChunkIndex]);
+ }
+
uint64_t Unset = (std::uint64_t)-1;
DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs());
- RemoteProjectStore::LoadAttachmentsResult Result = RemoteStore.LoadAttachments(Chunks);
+ RemoteProjectStore::LoadAttachmentsResult Result = Context.RemoteStore.LoadAttachments(Chunks);
if (Result.ErrorCode)
{
- ReportMessage(OptionalContext,
+ ReportMessage(Context.OptionalJobContext,
fmt::format("Failed to load attachments with {} chunks ({}): {}",
Chunks.size(),
RemoteResult.GetError(),
RemoteResult.GetErrorReason()));
Info.MissingAttachmentCount.fetch_add(1);
- if (IgnoreMissingAttachments)
+ if (Context.IgnoreMissingAttachments)
{
RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text);
}
return;
}
- Info.AttachmentsDownloaded.fetch_add(Chunks.size());
- ZEN_INFO("Loaded {} bulk attachments in {}",
- Chunks.size(),
- NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000)));
+ Info.AttachmentsDownloaded.fetch_add(Result.Chunks.size());
+ for (const auto& It : Result.Chunks)
+ {
+ uint64_t ChunkSize = It.second.GetCompressedSize();
+ Info.AttachmentBytesDownloaded.fetch_add(ChunkSize);
+ }
+ remotestore_impl::ReportMessage(Context.OptionalJobContext,
+ fmt::format("Loaded {} bulk attachments in {}",
+ Chunks.size(),
+ NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000))));
if (RemoteResult.IsError())
{
return;
}
AttachmentsWriteLatch.AddCount(1);
- WorkerPool.ScheduleWork(
- [&AttachmentsWriteLatch, &RemoteResult, &Info, &ChunkStore, Chunks = std::move(Result.Chunks)]() {
+ Context.WorkerPool.ScheduleWork(
+ [&AttachmentsWriteLatch, &RemoteResult, &Info, &Context, Chunks = std::move(Result.Chunks)]() {
auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); });
if (RemoteResult.IsError())
{
@@ -320,13 +360,13 @@ namespace remotestore_impl {
for (const auto& It : Chunks)
{
- uint64_t ChunkSize = It.second.GetCompressedSize();
- Info.AttachmentBytesDownloaded.fetch_add(ChunkSize);
WriteAttachmentBuffers.push_back(It.second.GetCompressed().Flatten().AsIoBuffer());
WriteRawHashes.push_back(It.first);
}
std::vector<CidStore::InsertResult> InsertResults =
- ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes, CidStore::InsertMode::kCopyOnly);
+ Context.ChunkStore.AddChunks(WriteAttachmentBuffers,
+ WriteRawHashes,
+ CidStore::InsertMode::kCopyOnly);
for (size_t Index = 0; Index < InsertResults.size(); Index++)
{
@@ -350,46 +390,38 @@ namespace remotestore_impl {
catch (const std::exception& Ex)
{
RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::InternalServerError),
- fmt::format("Failed to bulk load {} attachments", Chunks.size()),
+ fmt::format("Failed to bulk load {} attachments", NeededChunkIndexes.size()),
Ex.what());
}
},
WorkerThreadPool::EMode::EnableBacklog);
};
- void DownloadAndSaveBlock(CidStore& ChunkStore,
- RemoteProjectStore& RemoteStore,
- bool IgnoreMissingAttachments,
- JobContext* OptionalContext,
- WorkerThreadPool& NetworkWorkerPool,
- WorkerThreadPool& WorkerPool,
- Latch& AttachmentsDownloadLatch,
- Latch& AttachmentsWriteLatch,
- AsyncRemoteResult& RemoteResult,
- DownloadInfo& Info,
- Stopwatch& LoadAttachmentsTimer,
- std::atomic_uint64_t& DownloadStartMS,
- const IoHash& BlockHash,
- const std::vector<IoHash>& Chunks,
- uint32_t RetriesLeft)
+ void DownloadAndSaveBlock(LoadOplogContext& Context,
+ Latch& AttachmentsDownloadLatch,
+ Latch& AttachmentsWriteLatch,
+ AsyncRemoteResult& RemoteResult,
+ DownloadInfo& Info,
+ Stopwatch& LoadAttachmentsTimer,
+ std::atomic_uint64_t& DownloadStartMS,
+ const IoHash& BlockHash,
+ const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& AllNeededPartialChunkHashesLookup,
+ std::span<std::atomic<bool>> ChunkDownloadedFlags,
+ uint32_t RetriesLeft)
{
AttachmentsDownloadLatch.AddCount(1);
- NetworkWorkerPool.ScheduleWork(
+ Context.NetworkWorkerPool.ScheduleWork(
[&AttachmentsDownloadLatch,
&AttachmentsWriteLatch,
- &ChunkStore,
- &RemoteStore,
- &NetworkWorkerPool,
- &WorkerPool,
- BlockHash,
+ &Context,
&RemoteResult,
&Info,
&LoadAttachmentsTimer,
&DownloadStartMS,
- IgnoreMissingAttachments,
- OptionalContext,
RetriesLeft,
- Chunks = std::vector<IoHash>(Chunks)]() {
+ BlockHash = IoHash(BlockHash),
+ &AllNeededPartialChunkHashesLookup,
+ ChunkDownloadedFlags]() {
ZEN_TRACE_CPU("DownloadBlock");
auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); });
@@ -401,51 +433,65 @@ namespace remotestore_impl {
{
uint64_t Unset = (std::uint64_t)-1;
DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs());
- RemoteProjectStore::LoadAttachmentResult BlockResult = RemoteStore.LoadAttachment(BlockHash);
- if (BlockResult.ErrorCode)
+
+ IoBuffer BlobBuffer;
+ if (Context.OptionalCache)
{
- ReportMessage(OptionalContext,
- fmt::format("Failed to download block attachment {} ({}): {}",
- BlockHash,
- RemoteResult.GetError(),
- RemoteResult.GetErrorReason()));
- Info.MissingAttachmentCount.fetch_add(1);
- if (!IgnoreMissingAttachments)
- {
- RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text);
- }
- return;
+ BlobBuffer = Context.OptionalCache->GetBuildBlob(Context.CacheBuildId, BlockHash);
}
- if (RemoteResult.IsError())
+
+ if (!BlobBuffer)
{
- return;
+ RemoteProjectStore::LoadAttachmentResult BlockResult = Context.RemoteStore.LoadAttachment(BlockHash);
+ if (BlockResult.ErrorCode)
+ {
+ ReportMessage(Context.OptionalJobContext,
+ fmt::format("Failed to download block attachment {} ({}): {}",
+ BlockHash,
+ BlockResult.Reason,
+ BlockResult.Text));
+ Info.MissingAttachmentCount.fetch_add(1);
+ if (!Context.IgnoreMissingAttachments)
+ {
+ RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text);
+ }
+ return;
+ }
+ if (RemoteResult.IsError())
+ {
+ return;
+ }
+ BlobBuffer = std::move(BlockResult.Bytes);
+ ZEN_DEBUG("Loaded block attachment '{}' in {} ({})",
+ BlockHash,
+ NiceTimeSpanMs(static_cast<uint64_t>(BlockResult.ElapsedSeconds * 1000)),
+ NiceBytes(BlobBuffer.Size()));
+ if (Context.OptionalCache && Context.PopulateCache)
+ {
+ Context.OptionalCache->PutBuildBlob(Context.CacheBuildId,
+ BlockHash,
+ BlobBuffer.GetContentType(),
+ CompositeBuffer(SharedBuffer(BlobBuffer)));
+ }
}
- uint64_t BlockSize = BlockResult.Bytes.GetSize();
+ uint64_t BlockSize = BlobBuffer.GetSize();
Info.AttachmentBlocksDownloaded.fetch_add(1);
- ZEN_INFO("Loaded block attachment '{}' in {} ({})",
- BlockHash,
- NiceTimeSpanMs(static_cast<uint64_t>(BlockResult.ElapsedSeconds * 1000)),
- NiceBytes(BlockSize));
Info.AttachmentBlockBytesDownloaded.fetch_add(BlockSize);
AttachmentsWriteLatch.AddCount(1);
- WorkerPool.ScheduleWork(
+ Context.WorkerPool.ScheduleWork(
[&AttachmentsDownloadLatch,
&AttachmentsWriteLatch,
- &ChunkStore,
- &RemoteStore,
- &NetworkWorkerPool,
- &WorkerPool,
- BlockHash,
+ &Context,
&RemoteResult,
&Info,
&LoadAttachmentsTimer,
&DownloadStartMS,
- IgnoreMissingAttachments,
- OptionalContext,
RetriesLeft,
- Chunks = std::move(Chunks),
- Bytes = std::move(BlockResult.Bytes)]() {
+ BlockHash = IoHash(BlockHash),
+ &AllNeededPartialChunkHashesLookup,
+ ChunkDownloadedFlags,
+ Bytes = std::move(BlobBuffer)]() {
auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); });
if (RemoteResult.IsError())
{
@@ -454,64 +500,107 @@ namespace remotestore_impl {
try
{
ZEN_ASSERT(Bytes.Size() > 0);
- std::unordered_set<IoHash, IoHash::Hasher> WantedChunks;
- WantedChunks.reserve(Chunks.size());
- WantedChunks.insert(Chunks.begin(), Chunks.end());
std::vector<IoBuffer> WriteAttachmentBuffers;
std::vector<IoHash> WriteRawHashes;
IoHash RawHash;
uint64_t RawSize;
CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Bytes), RawHash, RawSize);
+
+ std::string ErrorString;
+
if (!Compressed)
{
- if (RetriesLeft > 0)
+ ErrorString =
+ fmt::format("Block attachment {} is malformed, can't parse as compressed binary", BlockHash);
+ }
+ else if (RawHash != BlockHash)
+ {
+ ErrorString = fmt::format("Block attachment {} has mismatching raw hash ({})", BlockHash, RawHash);
+ }
+ else if (CompositeBuffer BlockPayload = Compressed.DecompressToComposite(); !BlockPayload)
+ {
+ ErrorString = fmt::format("Block attachment {} is malformed, can't decompress payload", BlockHash);
+ }
+ else
+ {
+ uint64_t PotentialSize = 0;
+ uint64_t UsedSize = 0;
+ uint64_t BlockSize = BlockPayload.GetSize();
+
+ uint64_t BlockHeaderSize = 0;
+
+ bool StoreChunksOK = IterateChunkBlock(
+ BlockPayload.Flatten(),
+ [&AllNeededPartialChunkHashesLookup,
+ &ChunkDownloadedFlags,
+ &WriteAttachmentBuffers,
+ &WriteRawHashes,
+ &Info,
+ &PotentialSize](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) {
+ auto ChunkIndexIt = AllNeededPartialChunkHashesLookup.find(AttachmentRawHash);
+ if (ChunkIndexIt != AllNeededPartialChunkHashesLookup.end())
+ {
+ bool Expected = false;
+ if (ChunkDownloadedFlags[ChunkIndexIt->second].compare_exchange_strong(Expected, true))
+ {
+ WriteAttachmentBuffers.emplace_back(Chunk.GetCompressed().Flatten().AsIoBuffer());
+ IoHash RawHash;
+ uint64_t RawSize;
+ ZEN_ASSERT(CompressedBuffer::ValidateCompressedHeader(
+ WriteAttachmentBuffers.back(),
+ RawHash,
+ RawSize,
+ /*OutOptionalTotalCompressedSize*/ nullptr));
+ ZEN_ASSERT(RawHash == AttachmentRawHash);
+ WriteRawHashes.emplace_back(AttachmentRawHash);
+ PotentialSize += WriteAttachmentBuffers.back().GetSize();
+ }
+ }
+ },
+ BlockHeaderSize);
+
+ if (!StoreChunksOK)
{
- ReportMessage(
- OptionalContext,
- fmt::format(
- "Block attachment {} is malformed, can't parse as compressed binary, retrying download",
- BlockHash));
- return DownloadAndSaveBlock(ChunkStore,
- RemoteStore,
- IgnoreMissingAttachments,
- OptionalContext,
- NetworkWorkerPool,
- WorkerPool,
- AttachmentsDownloadLatch,
- AttachmentsWriteLatch,
- RemoteResult,
- Info,
- LoadAttachmentsTimer,
- DownloadStartMS,
- BlockHash,
- std::move(Chunks),
- RetriesLeft - 1);
+ ErrorString = fmt::format("Invalid format for block {}", BlockHash);
+ }
+ else
+ {
+ if (!WriteAttachmentBuffers.empty())
+ {
+ std::vector<CidStore::InsertResult> Results =
+ Context.ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes);
+ for (size_t Index = 0; Index < Results.size(); Index++)
+ {
+ const CidStore::InsertResult& Result = Results[Index];
+ if (Result.New)
+ {
+ Info.AttachmentBytesStored.fetch_add(WriteAttachmentBuffers[Index].GetSize());
+ Info.AttachmentsStored.fetch_add(1);
+ UsedSize += WriteAttachmentBuffers[Index].GetSize();
+ }
+ }
+ if (UsedSize < BlockSize)
+ {
+ ZEN_DEBUG("Used {} (skipping {}) out of {} for block {} ({} %) (use of matching {}%)",
+ NiceBytes(UsedSize),
+ NiceBytes(BlockSize - UsedSize),
+ NiceBytes(BlockSize),
+ BlockHash,
+ (100 * UsedSize) / BlockSize,
+ PotentialSize > 0 ? (UsedSize * 100) / PotentialSize : 0);
+ }
+ }
}
- ReportMessage(
- OptionalContext,
- fmt::format("Block attachment {} is malformed, can't parse as compressed binary", BlockHash));
- RemoteResult.SetError(
- gsl::narrow<int32_t>(HttpResponseCode::InternalServerError),
- fmt::format("Block attachment {} is malformed, can't parse as compressed binary", BlockHash),
- {});
- return;
}
- CompositeBuffer BlockPayload = Compressed.DecompressToComposite();
- if (!BlockPayload)
+
+ if (!ErrorString.empty())
{
if (RetriesLeft > 0)
{
- ReportMessage(
- OptionalContext,
- fmt::format("Block attachment {} is malformed, can't decompress payload, retrying download",
- BlockHash));
- return DownloadAndSaveBlock(ChunkStore,
- RemoteStore,
- IgnoreMissingAttachments,
- OptionalContext,
- NetworkWorkerPool,
- WorkerPool,
+ ReportMessage(Context.OptionalJobContext, fmt::format("{}, retrying download", ErrorString));
+
+ return DownloadAndSaveBlock(Context,
AttachmentsDownloadLatch,
AttachmentsWriteLatch,
RemoteResult,
@@ -519,91 +608,16 @@ namespace remotestore_impl {
LoadAttachmentsTimer,
DownloadStartMS,
BlockHash,
- std::move(Chunks),
+ AllNeededPartialChunkHashesLookup,
+ ChunkDownloadedFlags,
RetriesLeft - 1);
}
- ReportMessage(OptionalContext,
- fmt::format("Block attachment {} is malformed, can't decompress payload", BlockHash));
- RemoteResult.SetError(
- gsl::narrow<int32_t>(HttpResponseCode::InternalServerError),
- fmt::format("Block attachment {} is malformed, can't decompress payload", BlockHash),
- {});
- return;
- }
- if (RawHash != BlockHash)
- {
- ReportMessage(OptionalContext,
- fmt::format("Block attachment {} has mismatching raw hash ({})", BlockHash, RawHash));
- RemoteResult.SetError(
- gsl::narrow<int32_t>(HttpResponseCode::InternalServerError),
- fmt::format("Block attachment {} has mismatching raw hash ({})", BlockHash, RawHash),
- {});
- return;
- }
-
- uint64_t PotentialSize = 0;
- uint64_t UsedSize = 0;
- uint64_t BlockSize = BlockPayload.GetSize();
-
- uint64_t BlockHeaderSize = 0;
- bool StoreChunksOK = IterateChunkBlock(
- BlockPayload.Flatten(),
- [&WantedChunks, &WriteAttachmentBuffers, &WriteRawHashes, &Info, &PotentialSize](
- CompressedBuffer&& Chunk,
- const IoHash& AttachmentRawHash) {
- if (WantedChunks.contains(AttachmentRawHash))
- {
- WriteAttachmentBuffers.emplace_back(Chunk.GetCompressed().Flatten().AsIoBuffer());
- IoHash RawHash;
- uint64_t RawSize;
- ZEN_ASSERT(
- CompressedBuffer::ValidateCompressedHeader(WriteAttachmentBuffers.back(),
- RawHash,
- RawSize,
- /*OutOptionalTotalCompressedSize*/ nullptr));
- ZEN_ASSERT(RawHash == AttachmentRawHash);
- WriteRawHashes.emplace_back(AttachmentRawHash);
- WantedChunks.erase(AttachmentRawHash);
- PotentialSize += WriteAttachmentBuffers.back().GetSize();
- }
- },
- BlockHeaderSize);
-
- if (!StoreChunksOK)
- {
- ReportMessage(OptionalContext,
- fmt::format("Block attachment {} has invalid format ({}): {}",
- BlockHash,
- RemoteResult.GetError(),
- RemoteResult.GetErrorReason()));
- RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError),
- fmt::format("Invalid format for block {}", BlockHash),
- {});
- return;
- }
-
- ZEN_ASSERT(WantedChunks.empty());
-
- if (!WriteAttachmentBuffers.empty())
- {
- auto Results = ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes);
- for (size_t Index = 0; Index < Results.size(); Index++)
+ else
{
- const auto& Result = Results[Index];
- if (Result.New)
- {
- Info.AttachmentBytesStored.fetch_add(WriteAttachmentBuffers[Index].GetSize());
- Info.AttachmentsStored.fetch_add(1);
- UsedSize += WriteAttachmentBuffers[Index].GetSize();
- }
+ ReportMessage(Context.OptionalJobContext, ErrorString);
+ RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), ErrorString, {});
+ return;
}
- ZEN_DEBUG("Used {} (matching {}) out of {} for block {} ({} %) (use of matching {}%)",
- NiceBytes(UsedSize),
- NiceBytes(PotentialSize),
- NiceBytes(BlockSize),
- BlockHash,
- (100 * UsedSize) / BlockSize,
- PotentialSize > 0 ? (UsedSize * 100) / PotentialSize : 0);
}
}
catch (const std::exception& Ex)
@@ -618,19 +632,458 @@ namespace remotestore_impl {
catch (const std::exception& Ex)
{
RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::InternalServerError),
- fmt::format("Failed to block attachment {}", BlockHash),
+ fmt::format("Failed to download block attachment {}", BlockHash),
+ Ex.what());
+ }
+ },
+ WorkerThreadPool::EMode::EnableBacklog);
+ };
+
+ void DownloadPartialBlock(LoadOplogContext& Context,
+ AsyncRemoteResult& RemoteResult,
+ DownloadInfo& Info,
+ double& DownloadTimeSeconds,
+ const ChunkBlockDescription& BlockDescription,
+ bool BlockExistsInCache,
+ std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRangeDescriptors,
+ size_t BlockRangeIndexStart,
+ size_t BlockRangeCount,
+ std::function<void(IoBuffer&& Buffer,
+ size_t BlockRangeStartIndex,
+ std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths)>&& OnDownloaded)
+ {
+ ZEN_ASSERT(Context.StoreMaxRangeCountPerRequest != 0);
+ ZEN_ASSERT(BlockExistsInCache == false || Context.CacheMaxRangeCountPerRequest != 0);
+
+ std::vector<std::pair<uint64_t, uint64_t>> Ranges;
+ Ranges.reserve(BlockRangeDescriptors.size());
+ for (size_t BlockRangeIndex = BlockRangeIndexStart; BlockRangeIndex < BlockRangeIndexStart + BlockRangeCount; BlockRangeIndex++)
+ {
+ const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange = BlockRangeDescriptors[BlockRangeIndex];
+ Ranges.push_back(std::make_pair(BlockRange.RangeStart, BlockRange.RangeLength));
+ }
+
+ size_t SubBlockRangeCount = BlockRangeCount;
+ size_t SubRangeCountComplete = 0;
+ std::span<const std::pair<uint64_t, uint64_t>> RangesSpan(Ranges);
+
+ while (SubRangeCountComplete < SubBlockRangeCount)
+ {
+ if (RemoteResult.IsError())
+ {
+ break;
+ }
+
+ size_t SubRangeStartIndex = BlockRangeIndexStart + SubRangeCountComplete;
+ if (BlockExistsInCache)
+ {
+ ZEN_ASSERT(Context.OptionalCache);
+ size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, Context.CacheMaxRangeCountPerRequest);
+
+ if (SubRangeCount == 1)
+ {
+ // Legacy single-range path, prefer that for max compatibility
+
+ const std::pair<uint64_t, uint64_t> SubRange = RangesSpan[SubRangeCountComplete];
+ Stopwatch CacheTimer;
+ IoBuffer PayloadBuffer = Context.OptionalCache->GetBuildBlob(Context.CacheBuildId,
+ BlockDescription.BlockHash,
+ SubRange.first,
+ SubRange.second);
+ DownloadTimeSeconds += CacheTimer.GetElapsedTimeMs() / 1000.0;
+ if (RemoteResult.IsError())
+ {
+ break;
+ }
+ if (PayloadBuffer)
+ {
+ OnDownloaded(std::move(PayloadBuffer),
+ SubRangeStartIndex,
+ std::vector<std::pair<uint64_t, uint64_t>>{std::make_pair(0u, SubRange.second)});
+ SubRangeCountComplete += SubRangeCount;
+ continue;
+ }
+ }
+ else
+ {
+ auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount);
+
+ Stopwatch CacheTimer;
+ BuildStorageCache::BuildBlobRanges RangeBuffers =
+ Context.OptionalCache->GetBuildBlobRanges(Context.CacheBuildId, BlockDescription.BlockHash, SubRanges);
+ DownloadTimeSeconds += CacheTimer.GetElapsedTimeMs() / 1000.0;
+ if (RemoteResult.IsError())
+ {
+ break;
+ }
+ if (RangeBuffers.PayloadBuffer)
+ {
+ if (RangeBuffers.Ranges.empty())
+ {
+ SubRangeCount = Ranges.size() - SubRangeCountComplete;
+ OnDownloaded(std::move(RangeBuffers.PayloadBuffer),
+ SubRangeStartIndex,
+ RangesSpan.subspan(SubRangeCountComplete, SubRangeCount));
+ SubRangeCountComplete += SubRangeCount;
+ continue;
+ }
+ else if (RangeBuffers.Ranges.size() == SubRangeCount)
+ {
+ OnDownloaded(std::move(RangeBuffers.PayloadBuffer), SubRangeStartIndex, RangeBuffers.Ranges);
+ SubRangeCountComplete += SubRangeCount;
+ continue;
+ }
+ }
+ }
+ }
+
+ size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, Context.StoreMaxRangeCountPerRequest);
+
+ auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount);
+
+ RemoteProjectStore::LoadAttachmentRangesResult BlockResult =
+ Context.RemoteStore.LoadAttachmentRanges(BlockDescription.BlockHash, SubRanges);
+ DownloadTimeSeconds += BlockResult.ElapsedSeconds;
+ if (RemoteResult.IsError())
+ {
+ break;
+ }
+ if (BlockResult.ErrorCode || !BlockResult.Bytes)
+ {
+ ReportMessage(Context.OptionalJobContext,
+ fmt::format("Failed to download {} ranges from block attachment '{}' ({}): {}",
+ SubRanges.size(),
+ BlockDescription.BlockHash,
+ BlockResult.ErrorCode,
+ BlockResult.Reason));
+ Info.MissingAttachmentCount.fetch_add(1);
+ if (!Context.IgnoreMissingAttachments)
+ {
+ RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text);
+ break;
+ }
+ }
+ else
+ {
+ if (BlockResult.Ranges.empty())
+ {
+ // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3
+ // Use the whole payload for the remaining ranges
+
+ if (Context.OptionalCache && Context.PopulateCache)
+ {
+ Context.OptionalCache->PutBuildBlob(Context.CacheBuildId,
+ BlockDescription.BlockHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(std::vector<IoBuffer>{BlockResult.Bytes}));
+ if (RemoteResult.IsError())
+ {
+ break;
+ }
+ }
+ SubRangeCount = Ranges.size() - SubRangeCountComplete;
+ OnDownloaded(std::move(BlockResult.Bytes),
+ SubRangeStartIndex,
+ RangesSpan.subspan(SubRangeCountComplete, SubRangeCount));
+ }
+ else
+ {
+ if (BlockResult.Ranges.size() != SubRanges.size())
+ {
+ RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError),
+ fmt::format("Range response for block {} contains {} ranges, expected {} ranges",
+ BlockDescription.BlockHash,
+ BlockResult.Ranges.size(),
+ SubRanges.size()),
+ "");
+ break;
+ }
+ OnDownloaded(std::move(BlockResult.Bytes), SubRangeStartIndex, BlockResult.Ranges);
+ }
+ }
+
+ SubRangeCountComplete += SubRangeCount;
+ }
+ }
+
+ void DownloadAndSavePartialBlock(LoadOplogContext& Context,
+ Latch& AttachmentsDownloadLatch,
+ Latch& AttachmentsWriteLatch,
+ AsyncRemoteResult& RemoteResult,
+ DownloadInfo& Info,
+ Stopwatch& LoadAttachmentsTimer,
+ std::atomic_uint64_t& DownloadStartMS,
+ const ChunkBlockDescription& BlockDescription,
+ bool BlockExistsInCache,
+ std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRangeDescriptors,
+ size_t BlockRangeIndexStart,
+ size_t BlockRangeCount,
+ const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& AllNeededPartialChunkHashesLookup,
+ std::span<std::atomic<bool>> ChunkDownloadedFlags,
+ uint32_t RetriesLeft)
+ {
+ AttachmentsDownloadLatch.AddCount(1);
+ Context.NetworkWorkerPool.ScheduleWork(
+ [&AttachmentsDownloadLatch,
+ &AttachmentsWriteLatch,
+ &Context,
+ &RemoteResult,
+ &Info,
+ &LoadAttachmentsTimer,
+ &DownloadStartMS,
+ BlockDescription,
+ BlockExistsInCache,
+ BlockRangeDescriptors,
+ BlockRangeIndexStart,
+ BlockRangeCount,
+ &AllNeededPartialChunkHashesLookup,
+ ChunkDownloadedFlags,
+ RetriesLeft]() {
+ ZEN_TRACE_CPU("DownloadBlockRanges");
+
+ auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); });
+ try
+ {
+ uint64_t Unset = (std::uint64_t)-1;
+ DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs());
+
+ double DownloadElapsedSeconds = 0;
+ uint64_t DownloadedBytes = 0;
+
+ DownloadPartialBlock(
+ Context,
+ RemoteResult,
+ Info,
+ DownloadElapsedSeconds,
+ BlockDescription,
+ BlockExistsInCache,
+ BlockRangeDescriptors,
+ BlockRangeIndexStart,
+ BlockRangeCount,
+ [&](IoBuffer&& Buffer,
+ size_t BlockRangeStartIndex,
+ std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths) {
+ uint64_t BlockPartSize = Buffer.GetSize();
+ DownloadedBytes += BlockPartSize;
+
+ Info.AttachmentBlockRangeBytesDownloaded.fetch_add(BlockPartSize);
+ Info.AttachmentBlocksRangesDownloaded++;
+
+ AttachmentsWriteLatch.AddCount(1);
+ Context.WorkerPool.ScheduleWork(
+ [&AttachmentsWriteLatch,
+ &Context,
+ &AttachmentsDownloadLatch,
+ &RemoteResult,
+ &Info,
+ &LoadAttachmentsTimer,
+ &DownloadStartMS,
+ BlockDescription,
+ BlockExistsInCache,
+ BlockRangeDescriptors,
+ BlockRangeStartIndex,
+ &AllNeededPartialChunkHashesLookup,
+ ChunkDownloadedFlags,
+ RetriesLeft,
+ BlockPayload = std::move(Buffer),
+ OffsetAndLengths =
+ std::vector<std::pair<uint64_t, uint64_t>>(OffsetAndLengths.begin(), OffsetAndLengths.end())]() {
+ auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); });
+ try
+ {
+ ZEN_ASSERT(BlockPayload.Size() > 0);
+
+ size_t RangeCount = OffsetAndLengths.size();
+ for (size_t RangeOffset = 0; RangeOffset < RangeCount; RangeOffset++)
+ {
+ if (RemoteResult.IsError())
+ {
+ return;
+ }
+
+ const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange =
+ BlockRangeDescriptors[BlockRangeStartIndex + RangeOffset];
+ const std::pair<uint64_t, uint64_t>& OffsetAndLength = OffsetAndLengths[RangeOffset];
+ IoBuffer BlockRangeBuffer(BlockPayload, OffsetAndLength.first, OffsetAndLength.second);
+
+ std::vector<IoBuffer> WriteAttachmentBuffers;
+ std::vector<IoHash> WriteRawHashes;
+
+ uint64_t PotentialSize = 0;
+ uint64_t UsedSize = 0;
+ uint64_t BlockPartSize = BlockRangeBuffer.GetSize();
+
+ uint32_t OffsetInBlock = 0;
+ for (uint32_t ChunkBlockIndex = BlockRange.ChunkBlockIndexStart;
+ ChunkBlockIndex < BlockRange.ChunkBlockIndexStart + BlockRange.ChunkBlockIndexCount;
+ ChunkBlockIndex++)
+ {
+ if (RemoteResult.IsError())
+ {
+ break;
+ }
+
+ const uint32_t ChunkCompressedSize =
+ BlockDescription.ChunkCompressedLengths[ChunkBlockIndex];
+ const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex];
+
+ if (auto ChunkIndexIt = AllNeededPartialChunkHashesLookup.find(ChunkHash);
+ ChunkIndexIt != AllNeededPartialChunkHashesLookup.end())
+ {
+ if (!ChunkDownloadedFlags[ChunkIndexIt->second])
+ {
+ IoHash VerifyChunkHash;
+ uint64_t VerifyChunkSize;
+ CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed(
+ SharedBuffer(IoBuffer(BlockRangeBuffer, OffsetInBlock, ChunkCompressedSize)),
+ VerifyChunkHash,
+ VerifyChunkSize);
+
+ std::string ErrorString;
+
+ if (!CompressedChunk)
+ {
+ ErrorString = fmt::format(
+ "Chunk at {},{} in block attachment '{}' is not a valid compressed buffer",
+ OffsetInBlock,
+ ChunkCompressedSize,
+ BlockDescription.BlockHash);
+ }
+ else if (VerifyChunkHash != ChunkHash)
+ {
+ ErrorString = fmt::format(
+ "Chunk at {},{} in block attachment '{}' has mismatching hash, expected "
+ "{}, got {}",
+ OffsetInBlock,
+ ChunkCompressedSize,
+ BlockDescription.BlockHash,
+ ChunkHash,
+ VerifyChunkHash);
+ }
+ else if (VerifyChunkSize != BlockDescription.ChunkRawLengths[ChunkBlockIndex])
+ {
+ ErrorString = fmt::format(
+ "Chunk at {},{} in block attachment '{}' has mismatching raw size, "
+ "expected {}, "
+ "got {}",
+ OffsetInBlock,
+ ChunkCompressedSize,
+ BlockDescription.BlockHash,
+ BlockDescription.ChunkRawLengths[ChunkBlockIndex],
+ VerifyChunkSize);
+ }
+
+ if (!ErrorString.empty())
+ {
+ if (RetriesLeft > 0)
+ {
+ ReportMessage(Context.OptionalJobContext,
+ fmt::format("{}, retrying download", ErrorString));
+ return DownloadAndSavePartialBlock(Context,
+ AttachmentsDownloadLatch,
+ AttachmentsWriteLatch,
+ RemoteResult,
+ Info,
+ LoadAttachmentsTimer,
+ DownloadStartMS,
+ BlockDescription,
+ BlockExistsInCache,
+ BlockRangeDescriptors,
+ BlockRangeStartIndex,
+ RangeCount,
+ AllNeededPartialChunkHashesLookup,
+ ChunkDownloadedFlags,
+ RetriesLeft - 1);
+ }
+
+ ReportMessage(Context.OptionalJobContext, ErrorString);
+ Info.MissingAttachmentCount.fetch_add(1);
+ if (!Context.IgnoreMissingAttachments)
+ {
+ RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::NotFound),
+ "Malformed chunk block",
+ ErrorString);
+ }
+ }
+ else
+ {
+ bool Expected = false;
+ if (ChunkDownloadedFlags[ChunkIndexIt->second].compare_exchange_strong(Expected,
+ true))
+ {
+ WriteAttachmentBuffers.emplace_back(
+ CompressedChunk.GetCompressed().Flatten().AsIoBuffer());
+ WriteRawHashes.emplace_back(ChunkHash);
+ PotentialSize += WriteAttachmentBuffers.back().GetSize();
+ }
+ }
+ }
+ }
+ OffsetInBlock += ChunkCompressedSize;
+ }
+
+ if (!WriteAttachmentBuffers.empty())
+ {
+ std::vector<CidStore::InsertResult> Results =
+ Context.ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes);
+ for (size_t Index = 0; Index < Results.size(); Index++)
+ {
+ const CidStore::InsertResult& Result = Results[Index];
+ if (Result.New)
+ {
+ Info.AttachmentBytesStored.fetch_add(WriteAttachmentBuffers[Index].GetSize());
+ Info.AttachmentsStored.fetch_add(1);
+ UsedSize += WriteAttachmentBuffers[Index].GetSize();
+ }
+ }
+ if (UsedSize < BlockPartSize)
+ {
+ ZEN_DEBUG(
+ "Used {} (skipping {}) out of {} for block {} range {}, {} ({} %) (use of matching "
+ "{}%)",
+ NiceBytes(UsedSize),
+ NiceBytes(BlockPartSize - UsedSize),
+ NiceBytes(BlockPartSize),
+ BlockDescription.BlockHash,
+ BlockRange.RangeStart,
+ BlockRange.RangeLength,
+ (100 * UsedSize) / BlockPartSize,
+ PotentialSize > 0 ? (UsedSize * 100) / PotentialSize : 0);
+ }
+ }
+ }
+ }
+ catch (const std::exception& Ex)
+ {
+ RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::InternalServerError),
+ fmt::format("Failed saving {} ranges from block attachment {}",
+ OffsetAndLengths.size(),
+ BlockDescription.BlockHash),
+ Ex.what());
+ }
+ },
+ WorkerThreadPool::EMode::EnableBacklog);
+ });
+ if (!RemoteResult.IsError())
+ {
+ ZEN_DEBUG("Loaded {} ranges from block attachment '{}' in {} ({})",
+ BlockRangeCount,
+ BlockDescription.BlockHash,
+ NiceTimeSpanMs(static_cast<uint64_t>(DownloadElapsedSeconds * 1000)),
+ NiceBytes(DownloadedBytes));
+ }
+ }
+ catch (const std::exception& Ex)
+ {
+ RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::InternalServerError),
+ fmt::format("Failed to download block attachment {} ranges", BlockDescription.BlockHash),
Ex.what());
}
},
WorkerThreadPool::EMode::EnableBacklog);
};
- void DownloadAndSaveAttachment(CidStore& ChunkStore,
- RemoteProjectStore& RemoteStore,
- bool IgnoreMissingAttachments,
- JobContext* OptionalContext,
- WorkerThreadPool& NetworkWorkerPool,
- WorkerThreadPool& WorkerPool,
+ void DownloadAndSaveAttachment(LoadOplogContext& Context,
Latch& AttachmentsDownloadLatch,
Latch& AttachmentsWriteLatch,
AsyncRemoteResult& RemoteResult,
@@ -640,19 +1093,15 @@ namespace remotestore_impl {
const IoHash& RawHash)
{
AttachmentsDownloadLatch.AddCount(1);
- NetworkWorkerPool.ScheduleWork(
- [&RemoteStore,
- &ChunkStore,
- &WorkerPool,
+ Context.NetworkWorkerPool.ScheduleWork(
+ [&Context,
&RemoteResult,
&AttachmentsDownloadLatch,
&AttachmentsWriteLatch,
RawHash,
&LoadAttachmentsTimer,
&DownloadStartMS,
- &Info,
- IgnoreMissingAttachments,
- OptionalContext]() {
+ &Info]() {
ZEN_TRACE_CPU("DownloadAttachment");
auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); });
@@ -664,43 +1113,52 @@ namespace remotestore_impl {
{
uint64_t Unset = (std::uint64_t)-1;
DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs());
- RemoteProjectStore::LoadAttachmentResult AttachmentResult = RemoteStore.LoadAttachment(RawHash);
- if (AttachmentResult.ErrorCode)
+ IoBuffer BlobBuffer;
+ if (Context.OptionalCache)
{
- ReportMessage(OptionalContext,
- fmt::format("Failed to download large attachment {}: '{}', error code : {}",
- RawHash,
- AttachmentResult.Reason,
- AttachmentResult.ErrorCode));
- Info.MissingAttachmentCount.fetch_add(1);
- if (!IgnoreMissingAttachments)
+ BlobBuffer = Context.OptionalCache->GetBuildBlob(Context.CacheBuildId, RawHash);
+ }
+ if (!BlobBuffer)
+ {
+ RemoteProjectStore::LoadAttachmentResult AttachmentResult = Context.RemoteStore.LoadAttachment(RawHash);
+ if (AttachmentResult.ErrorCode)
{
- RemoteResult.SetError(AttachmentResult.ErrorCode, AttachmentResult.Reason, AttachmentResult.Text);
+ ReportMessage(Context.OptionalJobContext,
+ fmt::format("Failed to download large attachment {}: '{}', error code : {}",
+ RawHash,
+ AttachmentResult.Reason,
+ AttachmentResult.ErrorCode));
+ Info.MissingAttachmentCount.fetch_add(1);
+ if (!Context.IgnoreMissingAttachments)
+ {
+ RemoteResult.SetError(AttachmentResult.ErrorCode, AttachmentResult.Reason, AttachmentResult.Text);
+ }
+ return;
+ }
+ BlobBuffer = std::move(AttachmentResult.Bytes);
+ ZEN_DEBUG("Loaded large attachment '{}' in {} ({})",
+ RawHash,
+ NiceTimeSpanMs(static_cast<uint64_t>(AttachmentResult.ElapsedSeconds * 1000)),
+ NiceBytes(BlobBuffer.GetSize()));
+ if (Context.OptionalCache && Context.PopulateCache)
+ {
+ Context.OptionalCache->PutBuildBlob(Context.CacheBuildId,
+ RawHash,
+ BlobBuffer.GetContentType(),
+ CompositeBuffer(SharedBuffer(BlobBuffer)));
}
- return;
}
- uint64_t AttachmentSize = AttachmentResult.Bytes.GetSize();
- ZEN_INFO("Loaded large attachment '{}' in {} ({})",
- RawHash,
- NiceTimeSpanMs(static_cast<uint64_t>(AttachmentResult.ElapsedSeconds * 1000)),
- NiceBytes(AttachmentSize));
- Info.AttachmentsDownloaded.fetch_add(1);
if (RemoteResult.IsError())
{
return;
}
+ uint64_t AttachmentSize = BlobBuffer.GetSize();
+ Info.AttachmentsDownloaded.fetch_add(1);
Info.AttachmentBytesDownloaded.fetch_add(AttachmentSize);
AttachmentsWriteLatch.AddCount(1);
- WorkerPool.ScheduleWork(
- [&AttachmentsWriteLatch,
- &RemoteResult,
- &Info,
- &ChunkStore,
- RawHash,
- AttachmentSize,
- Bytes = std::move(AttachmentResult.Bytes),
- OptionalContext]() {
+ Context.WorkerPool.ScheduleWork(
+ [&Context, &AttachmentsWriteLatch, &RemoteResult, &Info, RawHash, AttachmentSize, Bytes = std::move(BlobBuffer)]() {
ZEN_TRACE_CPU("WriteAttachment");
auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); });
@@ -710,7 +1168,7 @@ namespace remotestore_impl {
}
try
{
- CidStore::InsertResult InsertResult = ChunkStore.AddChunk(Bytes, RawHash);
+ CidStore::InsertResult InsertResult = Context.ChunkStore.AddChunk(Bytes, RawHash);
if (InsertResult.New)
{
Info.AttachmentBytesStored.fetch_add(AttachmentSize);
@@ -1126,7 +1584,9 @@ namespace remotestore_impl {
uint64_t PartialTransferWallTimeMS = Timer.GetElapsedTimeMs();
ReportProgress(OptionalContext,
"Saving attachments"sv,
- fmt::format("{} remaining... {}", Remaining, GetStats(RemoteStore.GetStats(), PartialTransferWallTimeMS)),
+ fmt::format("{} remaining... {}",
+ Remaining,
+ GetStats(RemoteStore.GetStats(), /*OptionalCacheStats*/ nullptr, PartialTransferWallTimeMS)),
AttachmentsToSave,
Remaining);
}
@@ -1135,7 +1595,7 @@ namespace remotestore_impl {
{
ReportProgress(OptionalContext,
"Saving attachments"sv,
- fmt::format("{}", GetStats(RemoteStore.GetStats(), ElapsedTimeMS)),
+ fmt::format("{}", GetStats(RemoteStore.GetStats(), /*OptionalCacheStats*/ nullptr, ElapsedTimeMS)),
AttachmentsToSave,
0);
}
@@ -1146,7 +1606,7 @@ namespace remotestore_impl {
LargeAttachmentCountToUpload,
BulkAttachmentCountToUpload,
NiceTimeSpanMs(ElapsedTimeMS),
- GetStats(RemoteStore.GetStats(), ElapsedTimeMS)));
+ GetStats(RemoteStore.GetStats(), /*OptionalCacheStats*/ nullptr, ElapsedTimeMS)));
}
} // namespace remotestore_impl
@@ -1224,35 +1684,7 @@ BuildContainer(CidStore& ChunkStore,
{
using namespace std::literals;
- class JobContextLogOutput : public OperationLogOutput
- {
- public:
- JobContextLogOutput(JobContext* OptionalContext) : m_OptionalContext(OptionalContext) {}
- virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) override
- {
- ZEN_UNUSED(LogLevel);
- if (m_OptionalContext)
- {
- fmt::basic_memory_buffer<char, 250> MessageBuffer;
- fmt::vformat_to(fmt::appender(MessageBuffer), Format, Args);
- remotestore_impl::ReportMessage(m_OptionalContext, std::string_view(MessageBuffer.data(), MessageBuffer.size()));
- }
- }
-
- virtual void SetLogOperationName(std::string_view Name) override { ZEN_UNUSED(Name); }
- virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { ZEN_UNUSED(StepIndex, StepCount); }
- virtual uint32_t GetProgressUpdateDelayMS() override { return 0; }
- virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override
- {
- ZEN_UNUSED(InSubTask);
- return nullptr;
- }
-
- private:
- JobContext* m_OptionalContext;
- };
-
- std::unique_ptr<OperationLogOutput> LogOutput(std::make_unique<JobContextLogOutput>(OptionalContext));
+ std::unique_ptr<OperationLogOutput> LogOutput(std::make_unique<remotestore_impl::JobContextLogOutput>(OptionalContext));
size_t OpCount = 0;
@@ -1783,31 +2215,36 @@ BuildContainer(CidStore& ChunkStore,
}
ResolveAttachmentsLatch.CountDown();
- while (!ResolveAttachmentsLatch.Wait(1000))
{
- ptrdiff_t Remaining = ResolveAttachmentsLatch.Remaining();
- if (remotestore_impl::IsCancelled(OptionalContext))
+ ptrdiff_t AttachmentCountToUseForProgress = ResolveAttachmentsLatch.Remaining();
+ while (!ResolveAttachmentsLatch.Wait(1000))
{
- RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", "");
- remotestore_impl::ReportMessage(OptionalContext,
- fmt::format("Aborting ({}): {}", RemoteResult.GetError(), RemoteResult.GetErrorReason()));
- while (!ResolveAttachmentsLatch.Wait(1000))
+ ptrdiff_t Remaining = ResolveAttachmentsLatch.Remaining();
+ if (remotestore_impl::IsCancelled(OptionalContext))
{
- Remaining = ResolveAttachmentsLatch.Remaining();
- remotestore_impl::ReportProgress(OptionalContext,
- "Resolving attachments"sv,
- fmt::format("Aborting, {} attachments remaining...", Remaining),
- UploadAttachments.size(),
- Remaining);
+ RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", "");
+ remotestore_impl::ReportMessage(
+ OptionalContext,
+ fmt::format("Aborting ({}): {}", RemoteResult.GetError(), RemoteResult.GetErrorReason()));
+ while (!ResolveAttachmentsLatch.Wait(1000))
+ {
+ Remaining = ResolveAttachmentsLatch.Remaining();
+ remotestore_impl::ReportProgress(OptionalContext,
+ "Resolving attachments"sv,
+ fmt::format("Aborting, {} attachments remaining...", Remaining),
+ UploadAttachments.size(),
+ Remaining);
+ }
+ remotestore_impl::ReportProgress(OptionalContext, "Resolving attachments"sv, "Aborted"sv, UploadAttachments.size(), 0);
+ return {};
}
- remotestore_impl::ReportProgress(OptionalContext, "Resolving attachments"sv, "Aborted"sv, UploadAttachments.size(), 0);
- return {};
+ AttachmentCountToUseForProgress = Max(Remaining, AttachmentCountToUseForProgress);
+ remotestore_impl::ReportProgress(OptionalContext,
+ "Resolving attachments"sv,
+ fmt::format("{} remaining...", Remaining),
+ AttachmentCountToUseForProgress,
+ Remaining);
}
- remotestore_impl::ReportProgress(OptionalContext,
- "Resolving attachments"sv,
- fmt::format("{} remaining...", Remaining),
- UploadAttachments.size(),
- Remaining);
}
if (UploadAttachments.size() > 0)
{
@@ -2010,14 +2447,13 @@ BuildContainer(CidStore& ChunkStore,
AsyncOnBlock,
RemoteResult);
ComposedBlocks++;
+ // Worker will set Blocks[BlockIndex] = Block (including ChunkRawHashes) under shared lock
}
else
{
ZEN_INFO("Bulk group {} attachments", ChunkCount);
OnBlockChunks(std::move(ChunksInBlock));
- }
- {
- // We can share the lock as we are not resizing the vector and only touch BlockHash at our own index
+ // We can share the lock as we are not resizing the vector and only touch our own index
RwLock::SharedLockScope _(BlocksLock);
Blocks[BlockIndex].ChunkRawHashes = std::move(ChunkRawHashes);
}
@@ -2195,12 +2631,14 @@ BuildContainer(CidStore& ChunkStore,
0);
}
- remotestore_impl::ReportMessage(OptionalContext,
- fmt::format("Built oplog and collected {} attachments from {} ops into {} blocks and in {}",
- ChunkAssembleCount,
- TotalOpCount,
- GeneratedBlockCount,
- NiceTimeSpanMs(static_cast<uint64_t>(Timer.GetElapsedTimeMs()))));
+ remotestore_impl::ReportMessage(
+ OptionalContext,
+ fmt::format("Built oplog and collected {} attachments from {} ops into {} blocks and {} loose attachments in {}",
+ ChunkAssembleCount,
+ TotalOpCount,
+ GeneratedBlockCount,
+ LargeChunkHashes.size(),
+ NiceTimeSpanMs(static_cast<uint64_t>(Timer.GetElapsedTimeMs()))));
if (remotestore_impl::IsCancelled(OptionalContext))
{
@@ -2752,30 +3190,32 @@ SaveOplog(CidStore& ChunkStore,
remotestore_impl::LogRemoteStoreStatsDetails(RemoteStore.GetStats());
- remotestore_impl::ReportMessage(OptionalContext,
- fmt::format("Saved oplog '{}' {} in {} ({}), Blocks: {} ({}), Attachments: {} ({}) {}",
- RemoteStoreInfo.ContainerName,
- RemoteResult.GetError() == 0 ? "SUCCESS" : "FAILURE",
- NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000.0)),
- NiceBytes(Info.OplogSizeBytes),
- Info.AttachmentBlocksUploaded.load(),
- NiceBytes(Info.AttachmentBlockBytesUploaded.load()),
- Info.AttachmentsUploaded.load(),
- NiceBytes(Info.AttachmentBytesUploaded.load()),
- remotestore_impl::GetStats(RemoteStore.GetStats(), TransferWallTimeMS)));
+ remotestore_impl::ReportMessage(
+ OptionalContext,
+ fmt::format("Saved oplog '{}' {} in {} ({}), Blocks: {} ({}), Attachments: {} ({}) {}",
+ RemoteStoreInfo.ContainerName,
+ RemoteResult.GetError() == 0 ? "SUCCESS" : "FAILURE",
+ NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000.0)),
+ NiceBytes(Info.OplogSizeBytes),
+ Info.AttachmentBlocksUploaded.load(),
+ NiceBytes(Info.AttachmentBlockBytesUploaded.load()),
+ Info.AttachmentsUploaded.load(),
+ NiceBytes(Info.AttachmentBytesUploaded.load()),
+ remotestore_impl::GetStats(RemoteStore.GetStats(), /*OptionalCacheStats*/ nullptr, TransferWallTimeMS)));
return Result;
};
RemoteProjectStore::Result
-ParseOplogContainer(const CbObject& ContainerObject,
- const std::function<void(std::span<IoHash> RawHashes)>& OnReferencedAttachments,
- const std::function<bool(const IoHash& RawHash)>& HasAttachment,
- const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock,
- const std::function<void(const IoHash& RawHash)>& OnNeedAttachment,
- const std::function<void(const ChunkedInfo&)>& OnChunkedAttachment,
- CbObject& OutOplogSection,
- JobContext* OptionalContext)
+ParseOplogContainer(
+ const CbObject& ContainerObject,
+ const std::function<void(std::span<IoHash> RawHashes)>& OnReferencedAttachments,
+ const std::function<bool(const IoHash& RawHash)>& HasAttachment,
+ const std::function<void(ThinChunkBlockDescription&& ThinBlockDescription, std::vector<uint32_t>&& NeededChunkIndexes)>& OnNeedBlock,
+ const std::function<void(const IoHash& RawHash)>& OnNeedAttachment,
+ const std::function<void(const ChunkedInfo&)>& OnChunkedAttachment,
+ CbObject& OutOplogSection,
+ JobContext* OptionalContext)
{
using namespace std::literals;
@@ -2801,22 +3241,43 @@ ParseOplogContainer(const CbObject& ContainerObject,
"Section has unexpected data type",
"Failed to save oplog container"};
}
- std::unordered_set<IoHash, IoHash::Hasher> OpsAttachments;
+ std::unordered_set<IoHash, IoHash::Hasher> NeededAttachments;
{
CbArrayView OpsArray = OutOplogSection["ops"sv].AsArrayView();
+
+ size_t OpCount = OpsArray.Num();
+ size_t OpsCompleteCount = 0;
+
+ remotestore_impl::ReportMessage(OptionalContext, fmt::format("Scanning {} ops for attachments", OpCount));
+
for (CbFieldView OpEntry : OpsArray)
{
- OpEntry.IterateAttachments([&](CbFieldView FieldView) { OpsAttachments.insert(FieldView.AsAttachment()); });
+ OpEntry.IterateAttachments([&](CbFieldView FieldView) { NeededAttachments.insert(FieldView.AsAttachment()); });
if (remotestore_impl::IsCancelled(OptionalContext))
{
return RemoteProjectStore::Result{.ErrorCode = gsl::narrow<int>(HttpResponseCode::OK),
.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0,
.Reason = "Operation cancelled"};
}
+ OpsCompleteCount++;
+ if ((OpsCompleteCount & 4095) == 0)
+ {
+ remotestore_impl::ReportProgress(
+ OptionalContext,
+ "Scanning oplog"sv,
+ fmt::format("{} attachments found, {} ops remaining...", NeededAttachments.size(), OpCount - OpsCompleteCount),
+ OpCount,
+ OpCount - OpsCompleteCount);
+ }
}
+ remotestore_impl::ReportProgress(OptionalContext,
+ "Scanning oplog"sv,
+ fmt::format("{} attachments found", NeededAttachments.size()),
+ OpCount,
+ OpCount - OpsCompleteCount);
}
{
- std::vector<IoHash> ReferencedAttachments(OpsAttachments.begin(), OpsAttachments.end());
+ std::vector<IoHash> ReferencedAttachments(NeededAttachments.begin(), NeededAttachments.end());
OnReferencedAttachments(ReferencedAttachments);
}
@@ -2827,24 +3288,41 @@ ParseOplogContainer(const CbObject& ContainerObject,
.Reason = "Operation cancelled"};
}
- remotestore_impl::ReportMessage(OptionalContext, fmt::format("Oplog references {} attachments", OpsAttachments.size()));
+ remotestore_impl::ReportMessage(OptionalContext, fmt::format("Oplog references {} attachments", NeededAttachments.size()));
CbArrayView ChunkedFilesArray = ContainerObject["chunkedfiles"sv].AsArrayView();
for (CbFieldView ChunkedFileField : ChunkedFilesArray)
{
CbObjectView ChunkedFileView = ChunkedFileField.AsObjectView();
IoHash RawHash = ChunkedFileView["rawhash"sv].AsHash();
- if (OpsAttachments.contains(RawHash) && (!HasAttachment(RawHash)))
+ if (NeededAttachments.erase(RawHash) == 1)
{
- ChunkedInfo Chunked = ReadChunkedInfo(ChunkedFileView);
+ if (!HasAttachment(RawHash))
+ {
+ ChunkedInfo Chunked = ReadChunkedInfo(ChunkedFileView);
+
+ size_t NeededChunkAttachmentCount = 0;
- OnReferencedAttachments(Chunked.ChunkHashes);
- OpsAttachments.insert(Chunked.ChunkHashes.begin(), Chunked.ChunkHashes.end());
- OnChunkedAttachment(Chunked);
- ZEN_INFO("Requesting chunked attachment '{}' ({}) built from {} chunks",
- Chunked.RawHash,
- NiceBytes(Chunked.RawSize),
- Chunked.ChunkHashes.size());
+ OnReferencedAttachments(Chunked.ChunkHashes);
+ for (const IoHash& ChunkHash : Chunked.ChunkHashes)
+ {
+ if (!HasAttachment(ChunkHash))
+ {
+ if (NeededAttachments.insert(ChunkHash).second)
+ {
+ NeededChunkAttachmentCount++;
+ }
+ }
+ }
+ OnChunkedAttachment(Chunked);
+
+ remotestore_impl::ReportMessage(OptionalContext,
+ fmt::format("Requesting chunked attachment '{}' ({}) built from {} chunks, need {} chunks",
+ Chunked.RawHash,
+ NiceBytes(Chunked.RawSize),
+ Chunked.ChunkHashes.size(),
+ NeededChunkAttachmentCount));
+ }
}
if (remotestore_impl::IsCancelled(OptionalContext))
{
@@ -2854,6 +3332,8 @@ ParseOplogContainer(const CbObject& ContainerObject,
}
}
+ std::vector<ThinChunkBlockDescription> ThinBlocksDescriptions;
+
size_t NeedBlockCount = 0;
CbArrayView BlocksArray = ContainerObject["blocks"sv].AsArrayView();
for (CbFieldView BlockField : BlocksArray)
@@ -2863,45 +3343,38 @@ ParseOplogContainer(const CbObject& ContainerObject,
CbArrayView ChunksArray = BlockView["chunks"sv].AsArrayView();
- std::vector<IoHash> NeededChunks;
- NeededChunks.reserve(ChunksArray.Num());
- if (BlockHash == IoHash::Zero)
+ std::vector<IoHash> ChunkHashes;
+ ChunkHashes.reserve(ChunksArray.Num());
+ for (CbFieldView ChunkField : ChunksArray)
{
- for (CbFieldView ChunkField : ChunksArray)
- {
- IoHash ChunkHash = ChunkField.AsBinaryAttachment();
- if (OpsAttachments.erase(ChunkHash) == 1)
- {
- if (!HasAttachment(ChunkHash))
- {
- NeededChunks.emplace_back(ChunkHash);
- }
- }
- }
+ ChunkHashes.push_back(ChunkField.AsHash());
}
- else
+ ThinBlocksDescriptions.push_back(ThinChunkBlockDescription{.BlockHash = BlockHash, .ChunkRawHashes = std::move(ChunkHashes)});
+ }
+
+ for (ThinChunkBlockDescription& ThinBlockDescription : ThinBlocksDescriptions)
+ {
+ std::vector<uint32_t> NeededBlockChunkIndexes;
+ for (uint32_t ChunkIndex = 0; ChunkIndex < ThinBlockDescription.ChunkRawHashes.size(); ChunkIndex++)
{
- for (CbFieldView ChunkField : ChunksArray)
+ const IoHash& ChunkHash = ThinBlockDescription.ChunkRawHashes[ChunkIndex];
+ if (NeededAttachments.erase(ChunkHash) == 1)
{
- const IoHash ChunkHash = ChunkField.AsHash();
- if (OpsAttachments.erase(ChunkHash) == 1)
+ if (!HasAttachment(ChunkHash))
{
- if (!HasAttachment(ChunkHash))
- {
- NeededChunks.emplace_back(ChunkHash);
- }
+ NeededBlockChunkIndexes.push_back(ChunkIndex);
}
}
}
-
- if (!NeededChunks.empty())
+ if (!NeededBlockChunkIndexes.empty())
{
- OnNeedBlock(BlockHash, std::move(NeededChunks));
- if (BlockHash != IoHash::Zero)
+ if (ThinBlockDescription.BlockHash != IoHash::Zero)
{
NeedBlockCount++;
}
+ OnNeedBlock(std::move(ThinBlockDescription), std::move(NeededBlockChunkIndexes));
}
+
if (remotestore_impl::IsCancelled(OptionalContext))
{
return RemoteProjectStore::Result{.ErrorCode = gsl::narrow<int>(HttpResponseCode::OK),
@@ -2909,6 +3382,7 @@ ParseOplogContainer(const CbObject& ContainerObject,
.Reason = "Operation cancelled"};
}
}
+
remotestore_impl::ReportMessage(OptionalContext,
fmt::format("Requesting {} of {} attachment blocks", NeedBlockCount, BlocksArray.Num()));
@@ -2918,7 +3392,7 @@ ParseOplogContainer(const CbObject& ContainerObject,
{
IoHash AttachmentHash = LargeChunksField.AsBinaryAttachment();
- if (OpsAttachments.erase(AttachmentHash) == 1)
+ if (NeededAttachments.erase(AttachmentHash) == 1)
{
if (!HasAttachment(AttachmentHash))
{
@@ -2941,14 +3415,15 @@ ParseOplogContainer(const CbObject& ContainerObject,
}
RemoteProjectStore::Result
-SaveOplogContainer(ProjectStore::Oplog& Oplog,
- const CbObject& ContainerObject,
- const std::function<void(std::span<IoHash> RawHashes)>& OnReferencedAttachments,
- const std::function<bool(const IoHash& RawHash)>& HasAttachment,
- const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock,
- const std::function<void(const IoHash& RawHash)>& OnNeedAttachment,
- const std::function<void(const ChunkedInfo&)>& OnChunkedAttachment,
- JobContext* OptionalContext)
+SaveOplogContainer(
+ ProjectStore::Oplog& Oplog,
+ const CbObject& ContainerObject,
+ const std::function<void(std::span<IoHash> RawHashes)>& OnReferencedAttachments,
+ const std::function<bool(const IoHash& RawHash)>& HasAttachment,
+ const std::function<void(ThinChunkBlockDescription&& ThinBlockDescription, std::vector<uint32_t>&& NeededChunkIndexes)>& OnNeedBlock,
+ const std::function<void(const IoHash& RawHash)>& OnNeedAttachment,
+ const std::function<void(const ChunkedInfo&)>& OnChunkedAttachment,
+ JobContext* OptionalContext)
{
using namespace std::literals;
@@ -2972,18 +3447,12 @@ SaveOplogContainer(ProjectStore::Oplog& Oplog,
}
RemoteProjectStore::Result
-LoadOplog(CidStore& ChunkStore,
- RemoteProjectStore& RemoteStore,
- ProjectStore::Oplog& Oplog,
- WorkerThreadPool& NetworkWorkerPool,
- WorkerThreadPool& WorkerPool,
- bool ForceDownload,
- bool IgnoreMissingAttachments,
- bool CleanOplog,
- JobContext* OptionalContext)
+LoadOplog(LoadOplogContext&& Context)
{
using namespace std::literals;
+ std::unique_ptr<OperationLogOutput> LogOutput(std::make_unique<remotestore_impl::JobContextLogOutput>(Context.OptionalJobContext));
+
remotestore_impl::DownloadInfo Info;
Stopwatch Timer;
@@ -2991,25 +3460,25 @@ LoadOplog(CidStore& ChunkStore,
std::unordered_set<IoHash, IoHash::Hasher> Attachments;
uint64_t BlockCountToDownload = 0;
- RemoteProjectStore::RemoteStoreInfo RemoteStoreInfo = RemoteStore.GetInfo();
- remotestore_impl::ReportMessage(OptionalContext, fmt::format("Loading oplog container '{}'", RemoteStoreInfo.ContainerName));
+ RemoteProjectStore::RemoteStoreInfo RemoteStoreInfo = Context.RemoteStore.GetInfo();
+ remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Loading oplog container '{}'", RemoteStoreInfo.ContainerName));
uint64_t TransferWallTimeMS = 0;
Stopwatch LoadContainerTimer;
- RemoteProjectStore::LoadContainerResult LoadContainerResult = RemoteStore.LoadContainer();
+ RemoteProjectStore::LoadContainerResult LoadContainerResult = Context.RemoteStore.LoadContainer();
TransferWallTimeMS += LoadContainerTimer.GetElapsedTimeMs();
if (LoadContainerResult.ErrorCode)
{
remotestore_impl::ReportMessage(
- OptionalContext,
+ Context.OptionalJobContext,
fmt::format("Failed to load oplog container: '{}', error code: {}", LoadContainerResult.Reason, LoadContainerResult.ErrorCode));
return RemoteProjectStore::Result{.ErrorCode = LoadContainerResult.ErrorCode,
.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0,
.Reason = LoadContainerResult.Reason,
.Text = LoadContainerResult.Text};
}
- remotestore_impl::ReportMessage(OptionalContext,
+ remotestore_impl::ReportMessage(Context.OptionalJobContext,
fmt::format("Loaded container in {} ({})",
NiceTimeSpanMs(static_cast<uint64_t>(LoadContainerResult.ElapsedSeconds * 1000)),
NiceBytes(LoadContainerResult.ContainerObject.GetSize())));
@@ -3023,22 +3492,27 @@ LoadOplog(CidStore& ChunkStore,
Stopwatch LoadAttachmentsTimer;
std::atomic_uint64_t DownloadStartMS = (std::uint64_t)-1;
- auto HasAttachment = [&Oplog, &ChunkStore, ForceDownload](const IoHash& RawHash) {
- if (ForceDownload)
+ auto HasAttachment = [&Context](const IoHash& RawHash) {
+ if (Context.ForceDownload)
{
return false;
}
- if (ChunkStore.ContainsChunk(RawHash))
+ if (Context.ChunkStore.ContainsChunk(RawHash))
{
return true;
}
return false;
};
- auto OnNeedBlock = [&RemoteStore,
- &ChunkStore,
- &NetworkWorkerPool,
- &WorkerPool,
+ struct NeededBlockDownload
+ {
+ ThinChunkBlockDescription ThinBlockDescription;
+ std::vector<uint32_t> NeededChunkIndexes;
+ };
+
+ std::vector<NeededBlockDownload> NeededBlockDownloads;
+
+ auto OnNeedBlock = [&Context,
&AttachmentsDownloadLatch,
&AttachmentsWriteLatch,
&AttachmentCount,
@@ -3047,8 +3521,8 @@ LoadOplog(CidStore& ChunkStore,
&Info,
&LoadAttachmentsTimer,
&DownloadStartMS,
- IgnoreMissingAttachments,
- OptionalContext](const IoHash& BlockHash, std::vector<IoHash>&& Chunks) {
+ &NeededBlockDownloads](ThinChunkBlockDescription&& ThinBlockDescription,
+ std::vector<uint32_t>&& NeededChunkIndexes) {
if (RemoteResult.IsError())
{
return;
@@ -3056,47 +3530,26 @@ LoadOplog(CidStore& ChunkStore,
BlockCountToDownload++;
AttachmentCount.fetch_add(1);
- if (BlockHash == IoHash::Zero)
- {
- DownloadAndSaveBlockChunks(ChunkStore,
- RemoteStore,
- IgnoreMissingAttachments,
- OptionalContext,
- NetworkWorkerPool,
- WorkerPool,
+ if (ThinBlockDescription.BlockHash == IoHash::Zero)
+ {
+ DownloadAndSaveBlockChunks(Context,
AttachmentsDownloadLatch,
AttachmentsWriteLatch,
RemoteResult,
Info,
LoadAttachmentsTimer,
DownloadStartMS,
- Chunks);
+ std::move(ThinBlockDescription),
+ std::move(NeededChunkIndexes));
}
else
{
- DownloadAndSaveBlock(ChunkStore,
- RemoteStore,
- IgnoreMissingAttachments,
- OptionalContext,
- NetworkWorkerPool,
- WorkerPool,
- AttachmentsDownloadLatch,
- AttachmentsWriteLatch,
- RemoteResult,
- Info,
- LoadAttachmentsTimer,
- DownloadStartMS,
- BlockHash,
- Chunks,
- 3);
+ NeededBlockDownloads.push_back(NeededBlockDownload{.ThinBlockDescription = std::move(ThinBlockDescription),
+ .NeededChunkIndexes = std::move(NeededChunkIndexes)});
}
};
- auto OnNeedAttachment = [&RemoteStore,
- &Oplog,
- &ChunkStore,
- &NetworkWorkerPool,
- &WorkerPool,
+ auto OnNeedAttachment = [&Context,
&AttachmentsDownloadLatch,
&AttachmentsWriteLatch,
&RemoteResult,
@@ -3104,9 +3557,7 @@ LoadOplog(CidStore& ChunkStore,
&AttachmentCount,
&LoadAttachmentsTimer,
&DownloadStartMS,
- &Info,
- IgnoreMissingAttachments,
- OptionalContext](const IoHash& RawHash) {
+ &Info](const IoHash& RawHash) {
if (!Attachments.insert(RawHash).second)
{
return;
@@ -3116,12 +3567,7 @@ LoadOplog(CidStore& ChunkStore,
return;
}
AttachmentCount.fetch_add(1);
- DownloadAndSaveAttachment(ChunkStore,
- RemoteStore,
- IgnoreMissingAttachments,
- OptionalContext,
- NetworkWorkerPool,
- WorkerPool,
+ DownloadAndSaveAttachment(Context,
AttachmentsDownloadLatch,
AttachmentsWriteLatch,
RemoteResult,
@@ -3132,18 +3578,13 @@ LoadOplog(CidStore& ChunkStore,
};
std::vector<ChunkedInfo> FilesToDechunk;
- auto OnChunkedAttachment = [&Oplog, &ChunkStore, &FilesToDechunk, ForceDownload](const ChunkedInfo& Chunked) {
- if (ForceDownload || !ChunkStore.ContainsChunk(Chunked.RawHash))
- {
- FilesToDechunk.push_back(Chunked);
- }
- };
+ auto OnChunkedAttachment = [&FilesToDechunk](const ChunkedInfo& Chunked) { FilesToDechunk.push_back(Chunked); };
- auto OnReferencedAttachments = [&Oplog](std::span<IoHash> RawHashes) { Oplog.CaptureAddedAttachments(RawHashes); };
+ auto OnReferencedAttachments = [&Context](std::span<IoHash> RawHashes) { Context.Oplog.CaptureAddedAttachments(RawHashes); };
// Make sure we retain any attachments we download before writing the oplog
- Oplog.EnableUpdateCapture();
- auto _ = MakeGuard([&Oplog]() { Oplog.DisableUpdateCapture(); });
+ Context.Oplog.EnableUpdateCapture();
+ auto _ = MakeGuard([&Context]() { Context.Oplog.DisableUpdateCapture(); });
CbObject OplogSection;
RemoteProjectStore::Result Result = ParseOplogContainer(LoadContainerResult.ContainerObject,
@@ -3153,40 +3594,268 @@ LoadOplog(CidStore& ChunkStore,
OnNeedAttachment,
OnChunkedAttachment,
OplogSection,
- OptionalContext);
+ Context.OptionalJobContext);
if (Result.ErrorCode != 0)
{
RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text);
}
- remotestore_impl::ReportMessage(OptionalContext,
+ remotestore_impl::ReportMessage(Context.OptionalJobContext,
fmt::format("Parsed oplog in {}, found {} attachments, {} blocks and {} chunked files to download",
NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000.0)),
Attachments.size(),
BlockCountToDownload,
FilesToDechunk.size()));
- AttachmentsDownloadLatch.CountDown();
- while (!AttachmentsDownloadLatch.Wait(1000))
+ std::vector<IoHash> BlockHashes;
+ std::vector<IoHash> AllNeededChunkHashes;
+ BlockHashes.reserve(NeededBlockDownloads.size());
+ for (const NeededBlockDownload& BlockDownload : NeededBlockDownloads)
{
- ptrdiff_t Remaining = AttachmentsDownloadLatch.Remaining();
- if (remotestore_impl::IsCancelled(OptionalContext))
+ BlockHashes.push_back(BlockDownload.ThinBlockDescription.BlockHash);
+ for (uint32_t ChunkIndex : BlockDownload.NeededChunkIndexes)
{
- if (!RemoteResult.IsError())
+ AllNeededChunkHashes.push_back(BlockDownload.ThinBlockDescription.ChunkRawHashes[ChunkIndex]);
+ }
+ }
+
+ tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> AllNeededPartialChunkHashesLookup = BuildHashLookup(AllNeededChunkHashes);
+ std::vector<std::atomic<bool>> ChunkDownloadedFlags(AllNeededChunkHashes.size());
+ std::vector<bool> DownloadedViaLegacyChunkFlag(AllNeededChunkHashes.size(), false);
+ ChunkBlockAnalyser::BlockResult PartialBlocksResult;
+
+ remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Fetching descriptions for {} blocks", BlockHashes.size()));
+
+ RemoteProjectStore::GetBlockDescriptionsResult BlockDescriptions =
+ Context.RemoteStore.GetBlockDescriptions(BlockHashes, Context.OptionalCache, Context.CacheBuildId);
+
+ remotestore_impl::ReportMessage(Context.OptionalJobContext,
+ fmt::format("GetBlockDescriptions took {}. Found {} blocks",
+ NiceTimeSpanMs(uint64_t(BlockDescriptions.ElapsedSeconds * 1000)),
+ BlockDescriptions.Blocks.size()));
+
+ std::vector<IoHash> BlocksWithDescription;
+ BlocksWithDescription.reserve(BlockDescriptions.Blocks.size());
+ for (const ChunkBlockDescription& BlockDescription : BlockDescriptions.Blocks)
+ {
+ BlocksWithDescription.push_back(BlockDescription.BlockHash);
+ }
+ {
+ auto WantIt = NeededBlockDownloads.begin();
+ auto FindIt = BlockDescriptions.Blocks.begin();
+ while (WantIt != NeededBlockDownloads.end())
+ {
+ if (FindIt == BlockDescriptions.Blocks.end())
{
- RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", "");
+ // Fall back to full download as we can't get enough information about the block
+ DownloadAndSaveBlock(Context,
+ AttachmentsDownloadLatch,
+ AttachmentsWriteLatch,
+ RemoteResult,
+ Info,
+ LoadAttachmentsTimer,
+ DownloadStartMS,
+ WantIt->ThinBlockDescription.BlockHash,
+ AllNeededPartialChunkHashesLookup,
+ ChunkDownloadedFlags,
+ 3);
+ for (uint32_t BlockChunkIndex : WantIt->NeededChunkIndexes)
+ {
+ const IoHash& ChunkHash = WantIt->ThinBlockDescription.ChunkRawHashes[BlockChunkIndex];
+ auto It = AllNeededPartialChunkHashesLookup.find(ChunkHash);
+ ZEN_ASSERT(It != AllNeededPartialChunkHashesLookup.end());
+ uint32_t ChunkIndex = It->second;
+ DownloadedViaLegacyChunkFlag[ChunkIndex] = true;
+ }
+ WantIt++;
+ }
+ else if (WantIt->ThinBlockDescription.BlockHash == FindIt->BlockHash)
+ {
+ // Found
+ FindIt++;
+ WantIt++;
+ }
+ else
+ {
+ // Not a requested block?
+ ZEN_ASSERT(false);
}
}
- uint64_t PartialTransferWallTimeMS = TransferWallTimeMS;
- if (DownloadStartMS != (uint64_t)-1)
+ }
+ if (!AllNeededChunkHashes.empty())
+ {
+ std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> PartialBlockDownloadModes;
+ std::vector<bool> BlockExistsInCache(BlocksWithDescription.size(), false);
+
+ if (Context.PartialBlockRequestMode == EPartialBlockRequestMode::Off)
{
- PartialTransferWallTimeMS += LoadAttachmentsTimer.GetElapsedTimeMs() - DownloadStartMS.load();
+ PartialBlockDownloadModes.resize(BlocksWithDescription.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::Off);
+ }
+ else
+ {
+ if (Context.OptionalCache)
+ {
+ std::vector<BuildStorageCache::BlobExistsResult> CacheExistsResult =
+ Context.OptionalCache->BlobsExists(Context.CacheBuildId, BlocksWithDescription);
+ if (CacheExistsResult.size() == BlocksWithDescription.size())
+ {
+ for (size_t BlobIndex = 0; BlobIndex < CacheExistsResult.size(); BlobIndex++)
+ {
+ BlockExistsInCache[BlobIndex] = CacheExistsResult[BlobIndex].HasBody;
+ }
+ }
+ uint64_t FoundBlocks =
+ std::accumulate(BlockExistsInCache.begin(),
+ BlockExistsInCache.end(),
+ uint64_t(0u),
+ [](uint64_t Current, bool Exists) -> uint64_t { return Current + (Exists ? 1 : 0); });
+ if (FoundBlocks > 0)
+ {
+ remotestore_impl::ReportMessage(
+ Context.OptionalJobContext,
+ fmt::format("Found {} out of {} blocks in cache", FoundBlocks, BlockExistsInCache.size()));
+ }
+ }
+
+ ChunkBlockAnalyser::EPartialBlockDownloadMode CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off;
+ ChunkBlockAnalyser::EPartialBlockDownloadMode CachePartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off;
+
+ switch (Context.PartialBlockRequestMode)
+ {
+ case EPartialBlockRequestMode::Off:
+ break;
+ case EPartialBlockRequestMode::ZenCacheOnly:
+ CachePartialDownloadMode = Context.CacheMaxRangeCountPerRequest > 1
+ ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed
+ : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange;
+ CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off;
+ break;
+ case EPartialBlockRequestMode::Mixed:
+ CachePartialDownloadMode = Context.CacheMaxRangeCountPerRequest > 1
+ ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed
+ : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange;
+ CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange;
+ break;
+ case EPartialBlockRequestMode::All:
+ CachePartialDownloadMode = Context.CacheMaxRangeCountPerRequest > 1
+ ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed
+ : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange;
+ CloudPartialDownloadMode = Context.StoreMaxRangeCountPerRequest > 1
+ ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange
+ : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange;
+ break;
+ }
+
+ PartialBlockDownloadModes.reserve(BlocksWithDescription.size());
+ for (uint32_t BlockIndex = 0; BlockIndex < BlocksWithDescription.size(); BlockIndex++)
+ {
+ const bool BlockExistInCache = BlockExistsInCache[BlockIndex];
+ PartialBlockDownloadModes.push_back(BlockExistInCache ? CachePartialDownloadMode : CloudPartialDownloadMode);
+ }
+ }
+
+ ZEN_ASSERT(PartialBlockDownloadModes.size() == BlocksWithDescription.size());
+
+ ChunkBlockAnalyser PartialAnalyser(
+ *LogOutput,
+ BlockDescriptions.Blocks,
+ ChunkBlockAnalyser::Options{.IsQuiet = false,
+ .IsVerbose = false,
+ .HostLatencySec = Context.StoreLatencySec,
+ .HostHighSpeedLatencySec = Context.CacheLatencySec,
+ .HostMaxRangeCountPerRequest = Context.StoreMaxRangeCountPerRequest,
+ .HostHighSpeedMaxRangeCountPerRequest = Context.CacheMaxRangeCountPerRequest});
+
+ std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks =
+ PartialAnalyser.GetNeeded(AllNeededPartialChunkHashesLookup,
+ [&](uint32_t ChunkIndex) { return !DownloadedViaLegacyChunkFlag[ChunkIndex]; });
+
+ PartialBlocksResult = PartialAnalyser.CalculatePartialBlockDownloads(NeededBlocks, PartialBlockDownloadModes);
+
+ for (uint32_t FullBlockIndex : PartialBlocksResult.FullBlockIndexes)
+ {
+ DownloadAndSaveBlock(Context,
+ AttachmentsDownloadLatch,
+ AttachmentsWriteLatch,
+ RemoteResult,
+ Info,
+ LoadAttachmentsTimer,
+ DownloadStartMS,
+ BlockDescriptions.Blocks[FullBlockIndex].BlockHash,
+ AllNeededPartialChunkHashesLookup,
+ ChunkDownloadedFlags,
+ 3);
+ }
+
+ for (size_t BlockRangeIndex = 0; BlockRangeIndex < PartialBlocksResult.BlockRanges.size();)
+ {
+ size_t RangeCount = 1;
+ size_t RangesLeft = PartialBlocksResult.BlockRanges.size() - BlockRangeIndex;
+ const ChunkBlockAnalyser::BlockRangeDescriptor& CurrentBlockRange = PartialBlocksResult.BlockRanges[BlockRangeIndex];
+ while (RangeCount < RangesLeft &&
+ CurrentBlockRange.BlockIndex == PartialBlocksResult.BlockRanges[BlockRangeIndex + RangeCount].BlockIndex)
+ {
+ RangeCount++;
+ }
+
+ DownloadAndSavePartialBlock(Context,
+ AttachmentsDownloadLatch,
+ AttachmentsWriteLatch,
+ RemoteResult,
+ Info,
+ LoadAttachmentsTimer,
+ DownloadStartMS,
+ BlockDescriptions.Blocks[CurrentBlockRange.BlockIndex],
+ BlockExistsInCache[CurrentBlockRange.BlockIndex],
+ PartialBlocksResult.BlockRanges,
+ BlockRangeIndex,
+ RangeCount,
+ AllNeededPartialChunkHashesLookup,
+ ChunkDownloadedFlags,
+ /* RetriesLeft*/ 3);
+
+ BlockRangeIndex += RangeCount;
+ }
+ }
+
+ AttachmentsDownloadLatch.CountDown();
+ {
+ ptrdiff_t AttachmentCountToUseForProgress = AttachmentsDownloadLatch.Remaining();
+ while (!AttachmentsDownloadLatch.Wait(1000))
+ {
+ ptrdiff_t Remaining = AttachmentsDownloadLatch.Remaining();
+ if (remotestore_impl::IsCancelled(Context.OptionalJobContext))
+ {
+ if (!RemoteResult.IsError())
+ {
+ RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", "");
+ }
+ }
+ uint64_t PartialTransferWallTimeMS = TransferWallTimeMS;
+ if (DownloadStartMS != (uint64_t)-1)
+ {
+ PartialTransferWallTimeMS += LoadAttachmentsTimer.GetElapsedTimeMs() - DownloadStartMS.load();
+ }
+
+ uint64_t AttachmentsDownloaded =
+ Info.AttachmentBlocksDownloaded.load() + Info.AttachmentBlocksRangesDownloaded.load() + Info.AttachmentsDownloaded.load();
+ uint64_t AttachmentBytesDownloaded = Info.AttachmentBlockBytesDownloaded.load() +
+ Info.AttachmentBlockRangeBytesDownloaded.load() + Info.AttachmentBytesDownloaded.load();
+
+ AttachmentCountToUseForProgress = Max(Remaining, AttachmentCountToUseForProgress);
+ remotestore_impl::ReportProgress(
+ Context.OptionalJobContext,
+ "Loading attachments"sv,
+ fmt::format(
+ "{} ({}) downloaded, {} ({}) stored, {} remaining. {}",
+ AttachmentsDownloaded,
+ NiceBytes(AttachmentBytesDownloaded),
+ Info.AttachmentsStored.load(),
+ NiceBytes(Info.AttachmentBytesStored.load()),
+ Remaining,
+ remotestore_impl::GetStats(Context.RemoteStore.GetStats(), Context.OptionalCacheStats, PartialTransferWallTimeMS)),
+ AttachmentCountToUseForProgress,
+ Remaining);
}
- remotestore_impl::ReportProgress(
- OptionalContext,
- "Loading attachments"sv,
- fmt::format("{} remaining. {}", Remaining, remotestore_impl::GetStats(RemoteStore.GetStats(), PartialTransferWallTimeMS)),
- AttachmentCount.load(),
- Remaining);
}
if (DownloadStartMS != (uint64_t)-1)
{
@@ -3195,57 +3864,58 @@ LoadOplog(CidStore& ChunkStore,
if (AttachmentCount.load() > 0)
{
- remotestore_impl::ReportProgress(OptionalContext,
- "Loading attachments"sv,
- fmt::format("{}", remotestore_impl::GetStats(RemoteStore.GetStats(), TransferWallTimeMS)),
- AttachmentCount.load(),
- 0);
+ remotestore_impl::ReportProgress(
+ Context.OptionalJobContext,
+ "Loading attachments"sv,
+ fmt::format("{}", remotestore_impl::GetStats(Context.RemoteStore.GetStats(), Context.OptionalCacheStats, TransferWallTimeMS)),
+ AttachmentCount.load(),
+ 0);
}
AttachmentsWriteLatch.CountDown();
- while (!AttachmentsWriteLatch.Wait(1000))
{
- ptrdiff_t Remaining = AttachmentsWriteLatch.Remaining();
- if (remotestore_impl::IsCancelled(OptionalContext))
+ ptrdiff_t AttachmentCountToUseForProgress = AttachmentsWriteLatch.Remaining();
+ while (!AttachmentsWriteLatch.Wait(1000))
{
- if (!RemoteResult.IsError())
+ ptrdiff_t Remaining = AttachmentsWriteLatch.Remaining();
+ if (remotestore_impl::IsCancelled(Context.OptionalJobContext))
{
- RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", "");
+ if (!RemoteResult.IsError())
+ {
+ RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", "");
+ }
}
+ AttachmentCountToUseForProgress = Max(Remaining, AttachmentCountToUseForProgress);
+ remotestore_impl::ReportProgress(Context.OptionalJobContext,
+ "Writing attachments"sv,
+ fmt::format("{} ({}), {} remaining.",
+ Info.AttachmentsStored.load(),
+ NiceBytes(Info.AttachmentBytesStored.load()),
+ Remaining),
+ AttachmentCountToUseForProgress,
+ Remaining);
}
- remotestore_impl::ReportProgress(OptionalContext,
- "Writing attachments"sv,
- fmt::format("{} remaining.", Remaining),
- AttachmentCount.load(),
- Remaining);
}
if (AttachmentCount.load() > 0)
{
- remotestore_impl::ReportProgress(OptionalContext, "Writing attachments", ""sv, AttachmentCount.load(), 0);
+ remotestore_impl::ReportProgress(Context.OptionalJobContext, "Writing attachments", ""sv, AttachmentCount.load(), 0);
}
if (Result.ErrorCode == 0)
{
if (!FilesToDechunk.empty())
{
- remotestore_impl::ReportMessage(OptionalContext, fmt::format("Dechunking {} attachments", FilesToDechunk.size()));
+ remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Dechunking {} attachments", FilesToDechunk.size()));
Latch DechunkLatch(1);
- std::filesystem::path TempFilePath = Oplog.TempPath();
+ std::filesystem::path TempFilePath = Context.Oplog.TempPath();
for (const ChunkedInfo& Chunked : FilesToDechunk)
{
std::filesystem::path TempFileName = TempFilePath / Chunked.RawHash.ToHexString();
DechunkLatch.AddCount(1);
- WorkerPool.ScheduleWork(
- [&ChunkStore,
- &DechunkLatch,
- TempFileName,
- &Chunked,
- &RemoteResult,
- IgnoreMissingAttachments,
- &Info,
- OptionalContext]() {
+ Context.WorkerPool.ScheduleWork(
+ [&Context, &DechunkLatch, TempFileName, &Chunked, &RemoteResult, &Info]() {
ZEN_TRACE_CPU("DechunkAttachment");
auto _ = MakeGuard([&DechunkLatch, &TempFileName] {
@@ -3279,16 +3949,16 @@ LoadOplog(CidStore& ChunkStore,
for (std::uint32_t SequenceIndex : Chunked.ChunkSequence)
{
const IoHash& ChunkHash = Chunked.ChunkHashes[SequenceIndex];
- IoBuffer Chunk = ChunkStore.FindChunkByCid(ChunkHash);
+ IoBuffer Chunk = Context.ChunkStore.FindChunkByCid(ChunkHash);
if (!Chunk)
{
remotestore_impl::ReportMessage(
- OptionalContext,
+ Context.OptionalJobContext,
fmt::format("Missing chunk {} for chunked attachment {}", ChunkHash, Chunked.RawHash));
// We only add 1 as the resulting missing count will be 1 for the dechunked file
Info.MissingAttachmentCount.fetch_add(1);
- if (!IgnoreMissingAttachments)
+ if (!Context.IgnoreMissingAttachments)
{
RemoteResult.SetError(
gsl::narrow<int>(HttpResponseCode::NotFound),
@@ -3306,7 +3976,7 @@ LoadOplog(CidStore& ChunkStore,
if (RawHash != ChunkHash)
{
remotestore_impl::ReportMessage(
- OptionalContext,
+ Context.OptionalJobContext,
fmt::format("Mismatching raw hash {} for chunk {} for chunked attachment {}",
RawHash,
ChunkHash,
@@ -3314,7 +3984,7 @@ LoadOplog(CidStore& ChunkStore,
// We only add 1 as the resulting missing count will be 1 for the dechunked file
Info.MissingAttachmentCount.fetch_add(1);
- if (!IgnoreMissingAttachments)
+ if (!Context.IgnoreMissingAttachments)
{
RemoteResult.SetError(
gsl::narrow<int>(HttpResponseCode::NotFound),
@@ -3351,14 +4021,14 @@ LoadOplog(CidStore& ChunkStore,
}))
{
remotestore_impl::ReportMessage(
- OptionalContext,
+ Context.OptionalJobContext,
fmt::format("Failed to decompress chunk {} for chunked attachment {}",
ChunkHash,
Chunked.RawHash));
// We only add 1 as the resulting missing count will be 1 for the dechunked file
Info.MissingAttachmentCount.fetch_add(1);
- if (!IgnoreMissingAttachments)
+ if (!Context.IgnoreMissingAttachments)
{
RemoteResult.SetError(
gsl::narrow<int>(HttpResponseCode::NotFound),
@@ -3380,11 +4050,12 @@ LoadOplog(CidStore& ChunkStore,
TmpFile.Close();
TmpBuffer = IoBufferBuilder::MakeFromTemporaryFile(TempFileName);
}
+ uint64_t TmpBufferSize = TmpBuffer.GetSize();
CidStore::InsertResult InsertResult =
- ChunkStore.AddChunk(TmpBuffer, Chunked.RawHash, CidStore::InsertMode::kMayBeMovedInPlace);
+ Context.ChunkStore.AddChunk(TmpBuffer, Chunked.RawHash, CidStore::InsertMode::kMayBeMovedInPlace);
if (InsertResult.New)
{
- Info.AttachmentBytesStored.fetch_add(TmpBuffer.GetSize());
+ Info.AttachmentBytesStored.fetch_add(TmpBufferSize);
Info.AttachmentsStored.fetch_add(1);
}
@@ -3407,54 +4078,58 @@ LoadOplog(CidStore& ChunkStore,
while (!DechunkLatch.Wait(1000))
{
ptrdiff_t Remaining = DechunkLatch.Remaining();
- if (remotestore_impl::IsCancelled(OptionalContext))
+ if (remotestore_impl::IsCancelled(Context.OptionalJobContext))
{
if (!RemoteResult.IsError())
{
RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::OK), "Operation cancelled", "");
remotestore_impl::ReportMessage(
- OptionalContext,
+ Context.OptionalJobContext,
fmt::format("Aborting ({}): {}", RemoteResult.GetError(), RemoteResult.GetErrorReason()));
}
}
- remotestore_impl::ReportProgress(OptionalContext,
+ remotestore_impl::ReportProgress(Context.OptionalJobContext,
"Dechunking attachments"sv,
fmt::format("{} remaining...", Remaining),
FilesToDechunk.size(),
Remaining);
}
- remotestore_impl::ReportProgress(OptionalContext, "Dechunking attachments"sv, ""sv, FilesToDechunk.size(), 0);
+ remotestore_impl::ReportProgress(Context.OptionalJobContext, "Dechunking attachments"sv, ""sv, FilesToDechunk.size(), 0);
}
Result = RemoteResult.ConvertResult();
}
if (Result.ErrorCode == 0)
{
- if (CleanOplog)
+ if (Context.CleanOplog)
{
- RemoteStore.Flush();
- if (!Oplog.Reset())
+ if (Context.OptionalCache)
+ {
+ Context.OptionalCache->Flush(100, [](intptr_t) { return /*DontWaitForPendingOperation*/ false; });
+ }
+ if (!Context.Oplog.Reset())
{
Result = RemoteProjectStore::Result{.ErrorCode = gsl::narrow<int>(HttpResponseCode::InternalServerError),
.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0,
- .Reason = fmt::format("Failed to clean existing oplog '{}'", Oplog.OplogId())};
- remotestore_impl::ReportMessage(OptionalContext, fmt::format("Aborting ({}): {}", Result.ErrorCode, Result.Reason));
+ .Reason = fmt::format("Failed to clean existing oplog '{}'", Context.Oplog.OplogId())};
+ remotestore_impl::ReportMessage(Context.OptionalJobContext,
+ fmt::format("Aborting ({}): {}", Result.ErrorCode, Result.Reason));
}
}
if (Result.ErrorCode == 0)
{
- remotestore_impl::WriteOplogSection(Oplog, OplogSection, OptionalContext);
+ remotestore_impl::WriteOplogSection(Context.Oplog, OplogSection, Context.OptionalJobContext);
}
}
Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
- remotestore_impl::LogRemoteStoreStatsDetails(RemoteStore.GetStats());
+ remotestore_impl::LogRemoteStoreStatsDetails(Context.RemoteStore.GetStats());
{
std::string DownloadDetails;
RemoteProjectStore::ExtendedStats ExtendedStats;
- if (RemoteStore.GetExtendedStats(ExtendedStats))
+ if (Context.RemoteStore.GetExtendedStats(ExtendedStats))
{
if (!ExtendedStats.m_ReceivedBytesPerSource.empty())
{
@@ -3473,26 +4148,37 @@ LoadOplog(CidStore& ChunkStore,
Total += It.second;
}
- remotestore_impl::ReportMessage(OptionalContext, fmt::format("Downloaded {} ({})", NiceBytes(Total), SB.ToView()));
+ remotestore_impl::ReportMessage(Context.OptionalJobContext,
+ fmt::format("Downloaded {} ({})", NiceBytes(Total), SB.ToView()));
}
}
}
+ uint64_t TotalDownloads =
+ 1 + Info.AttachmentBlocksDownloaded.load() + Info.AttachmentBlocksRangesDownloaded.load() + Info.AttachmentsDownloaded.load();
+ uint64_t TotalBytesDownloaded = Info.OplogSizeBytes + Info.AttachmentBlockBytesDownloaded.load() +
+ Info.AttachmentBlockRangeBytesDownloaded.load() + Info.AttachmentBytesDownloaded.load();
+
remotestore_impl::ReportMessage(
- OptionalContext,
- fmt::format("Loaded oplog '{}' {} in {} ({}), Blocks: {} ({}), Attachments: {} ({}), Stored: {} ({}), Missing: {} {}",
+ Context.OptionalJobContext,
+ fmt::format("Loaded oplog '{}' {} in {} ({}), Blocks: {} ({}), BlockRanges: {} ({}), Attachments: {} "
+ "({}), Total: {} ({}), Stored: {} ({}), Missing: {} {}",
RemoteStoreInfo.ContainerName,
Result.ErrorCode == 0 ? "SUCCESS" : "FAILURE",
NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000.0)),
NiceBytes(Info.OplogSizeBytes),
Info.AttachmentBlocksDownloaded.load(),
NiceBytes(Info.AttachmentBlockBytesDownloaded.load()),
+ Info.AttachmentBlocksRangesDownloaded.load(),
+ NiceBytes(Info.AttachmentBlockRangeBytesDownloaded.load()),
Info.AttachmentsDownloaded.load(),
NiceBytes(Info.AttachmentBytesDownloaded.load()),
+ TotalDownloads,
+ NiceBytes(TotalBytesDownloaded),
Info.AttachmentsStored.load(),
NiceBytes(Info.AttachmentBytesStored.load()),
Info.MissingAttachmentCount.load(),
- remotestore_impl::GetStats(RemoteStore.GetStats(), TransferWallTimeMS)));
+ remotestore_impl::GetStats(Context.RemoteStore.GetStats(), Context.OptionalCacheStats, TransferWallTimeMS)));
return Result;
}
@@ -3537,7 +4223,7 @@ RemoteProjectStore::~RemoteProjectStore()
#if ZEN_WITH_TESTS
-namespace testutils {
+namespace projectstore_testutils {
using namespace std::literals;
static std::string OidAsString(const Oid& Id)
@@ -3589,7 +4275,29 @@ namespace testutils {
return Result;
}
-} // namespace testutils
+ // JobContext implementation for tests: never reports cancellation and logs every
+ // message/progress report prefixed with the current operation index so that output
+ // from consecutive LoadOplog calls is attributable to the right job.
+ class TestJobContext : public JobContext
+ {
+ public:
+ // OpIndex is held by reference on purpose: the test bumps a single counter between
+ // LoadOplog calls and subsequent log lines pick up the new value automatically.
+ explicit TestJobContext(int& OpIndex) : m_OpIndex(OpIndex) {}
+ // Tests never cancel, so import/export paths always run to completion.
+ bool IsCancelled() const override { return false; }
+ void ReportMessage(std::string_view Message) override { ZEN_INFO("Job {}: {}", m_OpIndex, Message); }
+ void ReportProgress(std::string_view CurrentOp, std::string_view Details, ptrdiff_t TotalCount, ptrdiff_t RemainingCount) override
+ {
+ // Progress is rendered as "done/total"; Details is optional and omitted when empty.
+ ZEN_INFO("Job {}: Op '{}'{} {}/{}",
+ m_OpIndex,
+ CurrentOp,
+ Details.empty() ? "" : fmt::format(" {}", Details),
+ TotalCount - RemainingCount,
+ TotalCount);
+ }
+
+ private:
+ int& m_OpIndex;
+ };
+
+} // namespace projectstore_testutils
+
+TEST_SUITE_BEGIN("remotestore.projectstore");
struct ExportForceDisableBlocksTrue_ForceTempBlocksFalse
{
@@ -3616,7 +4324,7 @@ TEST_CASE_TEMPLATE("project.store.export",
ExportForceDisableBlocksFalse_ForceTempBlocksTrue)
{
using namespace std::literals;
- using namespace testutils;
+ using namespace projectstore_testutils;
ScopedTemporaryDirectory TempDir;
ScopedTemporaryDirectory ExportDir;
@@ -3684,56 +4392,712 @@ TEST_CASE_TEMPLATE("project.store.export",
false,
nullptr);
- CHECK(ExportResult.ErrorCode == 0);
+ REQUIRE(ExportResult.ErrorCode == 0);
Ref<ProjectStore::Oplog> OplogImport = Project->NewOplog("oplog2", {});
CHECK(OplogImport);
- RemoteProjectStore::Result ImportResult = LoadOplog(CidStore,
- *RemoteStore,
- *OplogImport,
- NetworkPool,
- WorkerPool,
- /*Force*/ false,
- /*IgnoreMissingAttachments*/ false,
- /*CleanOplog*/ false,
- nullptr);
+ int OpJobIndex = 0;
+ TestJobContext OpJobContext(OpJobIndex);
+
+ RemoteProjectStore::Result ImportResult = LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ .RemoteStore = *RemoteStore,
+ .OptionalCache = nullptr,
+ .CacheBuildId = Oid::Zero,
+ .Oplog = *OplogImport,
+ .NetworkWorkerPool = NetworkPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = false,
+ .IgnoreMissingAttachments = false,
+ .CleanOplog = false,
+ .PartialBlockRequestMode = EPartialBlockRequestMode::Mixed,
+ .OptionalJobContext = &OpJobContext});
CHECK(ImportResult.ErrorCode == 0);
-
- RemoteProjectStore::Result ImportForceResult = LoadOplog(CidStore,
- *RemoteStore,
- *OplogImport,
- NetworkPool,
- WorkerPool,
- /*Force*/ true,
- /*IgnoreMissingAttachments*/ false,
- /*CleanOplog*/ false,
- nullptr);
+ OpJobIndex++;
+
+ RemoteProjectStore::Result ImportForceResult = LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ .RemoteStore = *RemoteStore,
+ .OptionalCache = nullptr,
+ .CacheBuildId = Oid::Zero,
+ .Oplog = *OplogImport,
+ .NetworkWorkerPool = NetworkPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = true,
+ .IgnoreMissingAttachments = false,
+ .CleanOplog = false,
+ .PartialBlockRequestMode = EPartialBlockRequestMode::Mixed,
+ .OptionalJobContext = &OpJobContext});
CHECK(ImportForceResult.ErrorCode == 0);
-
- RemoteProjectStore::Result ImportCleanResult = LoadOplog(CidStore,
- *RemoteStore,
- *OplogImport,
- NetworkPool,
- WorkerPool,
- /*Force*/ false,
- /*IgnoreMissingAttachments*/ false,
- /*CleanOplog*/ true,
- nullptr);
+ OpJobIndex++;
+
+ RemoteProjectStore::Result ImportCleanResult = LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ .RemoteStore = *RemoteStore,
+ .OptionalCache = nullptr,
+ .CacheBuildId = Oid::Zero,
+ .Oplog = *OplogImport,
+ .NetworkWorkerPool = NetworkPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = false,
+ .IgnoreMissingAttachments = false,
+ .CleanOplog = true,
+ .PartialBlockRequestMode = EPartialBlockRequestMode::Mixed,
+ .OptionalJobContext = &OpJobContext});
CHECK(ImportCleanResult.ErrorCode == 0);
-
- RemoteProjectStore::Result ImportForceCleanResult = LoadOplog(CidStore,
- *RemoteStore,
- *OplogImport,
- NetworkPool,
- WorkerPool,
- /*Force*/ true,
- /*IgnoreMissingAttachments*/ false,
- /*CleanOplog*/ true,
- nullptr);
+ OpJobIndex++;
+
+ RemoteProjectStore::Result ImportForceCleanResult =
+ LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ .RemoteStore = *RemoteStore,
+ .OptionalCache = nullptr,
+ .CacheBuildId = Oid::Zero,
+ .Oplog = *OplogImport,
+ .NetworkWorkerPool = NetworkPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = true,
+ .IgnoreMissingAttachments = false,
+ .CleanOplog = true,
+ .PartialBlockRequestMode = EPartialBlockRequestMode::Mixed,
+ .OptionalJobContext = &OpJobContext});
CHECK(ImportForceCleanResult.ErrorCode == 0);
+ OpJobIndex++;
}
+// Common oplog setup used by the two tests below.
+// Builds a FileRemoteStore backed by ExportDir, populates it with a SaveOplog call, and hands the
+// store back through OutRemoteStore. The returned Result is SaveOplog's result, or ErrorCode == -1
+// if the oplog could not be created.
+// Keeps the test data identical to project.store.export so the two test suites exercise the same blocks/attachments.
+static RemoteProjectStore::Result
+SetupExportStore(CidStore& CidStore,
+ ProjectStore::Project& Project,
+ WorkerThreadPool& NetworkPool,
+ WorkerThreadPool& WorkerPool,
+ const std::filesystem::path& ExportDir,
+ std::shared_ptr<RemoteProjectStore>& OutRemoteStore)
+{
+ // NOTE(review): the first parameter name shadows the CidStore type; legal here because
+ // no further CidStore declarations are needed inside this function.
+ using namespace projectstore_testutils;
+ using namespace std::literals;
+
+ Ref<ProjectStore::Oplog> Oplog = Project.NewOplog("oplog_export", {});
+ if (!Oplog)
+ {
+ return RemoteProjectStore::Result{.ErrorCode = -1};
+ }
+
+ // Five entries mirroring project.store.export: no attachments, one tiny attachment,
+ // several mixed-size attachments, two small attachments, and two large uncompressed ones.
+ Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage(Oid::NewOid(), {}));
+ Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{77})));
+ Oplog->AppendNewOplogEntry(
+ CreateBulkDataOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{7123, 583, 690, 99})));
+ Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{55, 122})));
+ Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage(
+ Oid::NewOid(),
+ CreateAttachments(std::initializer_list<size_t>{256u * 1024u, 92u * 1024u}, OodleCompressionLevel::None)));
+
+ // Small block/embed limits so even this modest data set produces multiple blocks.
+ FileRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = 64u * 1024,
+ .MaxChunksPerBlock = 1000,
+ .MaxChunkEmbedSize = 32 * 1024u,
+ .ChunkFileSizeLimit = 64u * 1024u},
+ /*.FolderPath =*/ExportDir,
+ /*.Name =*/std::string("oplog_export"),
+ /*.OptionalBaseName =*/std::string(),
+ /*.ForceDisableBlocks =*/false,
+ /*.ForceEnableTempBlocks =*/false};
+
+ OutRemoteStore = CreateFileRemoteStore(Log(), Options);
+ // Upload everything; the store is handed back even when SaveOplog fails so the caller
+ // can still inspect it.
+ return SaveOplog(CidStore,
+ *OutRemoteStore,
+ Project,
+ *Oplog,
+ NetworkPool,
+ WorkerPool,
+ Options.MaxBlockSize,
+ Options.MaxChunksPerBlock,
+ Options.MaxChunkEmbedSize,
+ Options.ChunkFileSizeLimit,
+ /*EmbedLooseFiles*/ true,
+ /*ForceUpload*/ false,
+ /*IgnoreMissingAttachments*/ false,
+ /*OptionalContext*/ nullptr);
+}
+
+// Creates an export store with a single oplog entry that packs six 512 KB chunks into one
+// ~3 MB block (MaxBlockSize = 8 MB). The resulting block slack (~1.5 MB) far exceeds the
+// 512 KB threshold that ChunkBlockAnalyser requires before it will consider partial-block
+// downloads instead of full-block downloads.
+//
+// This function is self-contained: it creates its own GcManager, CidStore, ProjectStore and
+// Project internally so that each call is independent of any outer test context. After
+// SaveOplog returns, all persistent data lives on disk inside ExportDir and the caller can
+// freely query OutRemoteStore without holding any references to the internal context.
+// Returns SaveOplog's result, or ErrorCode == -1 if the oplog could not be created.
+static RemoteProjectStore::Result
+SetupPartialBlockExportStore(WorkerThreadPool& NetworkPool,
+ WorkerThreadPool& WorkerPool,
+ const std::filesystem::path& ExportDir,
+ std::shared_ptr<RemoteProjectStore>& OutRemoteStore)
+{
+ using namespace projectstore_testutils;
+ using namespace std::literals;
+
+ // Self-contained CAS and project store. Subdirectories of ExportDir keep everything
+ // together without relying on the outer TEST_CASE's ExportCidStore / ExportProject.
+ GcManager LocalGc;
+ CidStore LocalCidStore(LocalGc);
+ CidStoreConfiguration LocalCidConfig = {.RootDirectory = ExportDir / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096};
+ LocalCidStore.Initialize(LocalCidConfig);
+
+ std::filesystem::path LocalProjectBasePath = ExportDir / "proj";
+ ProjectStore LocalProjectStore(LocalCidStore, LocalProjectBasePath, LocalGc, ProjectStore::Configuration{});
+ Ref<ProjectStore::Project> LocalProject(LocalProjectStore.NewProject(LocalProjectBasePath / "p"sv,
+ "p"sv,
+ (ExportDir / "root").string(),
+ (ExportDir / "engine").string(),
+ (ExportDir / "game").string(),
+ (ExportDir / "game" / "game.uproject").string()));
+
+ Ref<ProjectStore::Oplog> Oplog = LocalProject->NewOplog("oplog_partial_block", {});
+ if (!Oplog)
+ {
+ return RemoteProjectStore::Result{.ErrorCode = -1};
+ }
+
+ // Six 512 KB chunks with OodleCompressionLevel::None so the compressed size stays large
+ // and the block genuinely exceeds the 512 KB slack threshold.
+ Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage(
+ Oid::NewOid(),
+ CreateAttachments(std::initializer_list<size_t>{512u * 1024u, 512u * 1024u, 512u * 1024u, 512u * 1024u, 512u * 1024u, 512u * 1024u},
+ OodleCompressionLevel::None)));
+
+ // MaxChunkEmbedSize must be larger than the compressed size of each 512 KB chunk
+ // (OodleCompressionLevel::None → compressed ≈ raw ≈ 512 KB). With the legacy
+ // 32 KB limit all six chunks would become loose large attachments and no block would
+ // be created, so we use the production default of 1.5 MB instead.
+ FileRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = 8u * 1024u * 1024u,
+ .MaxChunksPerBlock = 1000,
+ .MaxChunkEmbedSize = RemoteStoreOptions::DefaultMaxChunkEmbedSize,
+ .ChunkFileSizeLimit = 64u * 1024u * 1024u},
+ /*.FolderPath =*/ExportDir,
+ /*.Name =*/std::string("oplog_partial_block"),
+ /*.OptionalBaseName =*/std::string(),
+ /*.ForceDisableBlocks =*/false,
+ /*.ForceEnableTempBlocks =*/false};
+ OutRemoteStore = CreateFileRemoteStore(Log(), Options);
+ // Upload the single-block payload; all state the caller needs afterwards is on disk.
+ return SaveOplog(LocalCidStore,
+ *OutRemoteStore,
+ *LocalProject,
+ *Oplog,
+ NetworkPool,
+ WorkerPool,
+ Options.MaxBlockSize,
+ Options.MaxChunksPerBlock,
+ Options.MaxChunkEmbedSize,
+ Options.ChunkFileSizeLimit,
+ /*EmbedLooseFiles*/ true,
+ /*ForceUpload*/ false,
+ /*IgnoreMissingAttachments*/ false,
+ /*OptionalContext*/ nullptr);
+}
+
+// Returns the first block hash that has at least MinChunkCount chunks, or a zero IoHash
+// if no qualifying block exists in Store. Any failure along the way (container load,
+// no block hashes, description fetch) also degrades to the zero-hash "not found" result.
+static IoHash
+FindBlockWithMultipleChunks(RemoteProjectStore& Store, size_t MinChunkCount)
+{
+ RemoteProjectStore::LoadContainerResult ContainerResult = Store.LoadContainer();
+ if (ContainerResult.ErrorCode != 0)
+ {
+ return {};
+ }
+ std::vector<IoHash> BlockHashes = GetBlockHashesFromOplog(ContainerResult.ContainerObject);
+ if (BlockHashes.empty())
+ {
+ return {};
+ }
+ // No cache / build id here: descriptions are fetched straight from the store.
+ RemoteProjectStore::GetBlockDescriptionsResult Descriptions = Store.GetBlockDescriptions(BlockHashes, nullptr, Oid{});
+ if (Descriptions.ErrorCode != 0)
+ {
+ return {};
+ }
+ // First match wins; the chunk count is judged by the number of raw chunk hashes.
+ for (const ChunkBlockDescription& Desc : Descriptions.Blocks)
+ {
+ if (Desc.ChunkRawHashes.size() >= MinChunkCount)
+ {
+ return Desc.BlockHash;
+ }
+ }
+ return {};
+}
+
+// Loads BlockHash from Source and inserts every even-indexed chunk (0, 2, 4, …) into
+// TargetCidStore. Odd-indexed chunks are left absent so that when an import is run
+// against the same block, HasAttachment returns false for three non-adjacent positions
+// — the minimum needed to exercise the multi-range partial-block download paths.
+// Best-effort: any download/decode failure silently leaves TargetCidStore unchanged.
+static void
+SeedCidStoreWithAlternateChunks(CidStore& TargetCidStore, RemoteProjectStore& Source, const IoHash& BlockHash)
+{
+ RemoteProjectStore::LoadAttachmentResult BlockResult = Source.LoadAttachment(BlockHash);
+ if (BlockResult.ErrorCode != 0 || !BlockResult.Bytes)
+ {
+ return;
+ }
+
+ // The downloaded block is a compressed buffer; decompress it to get the raw
+ // chunk-block payload that IterateChunkBlock understands.
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(BlockResult.Bytes), RawHash, RawSize);
+ if (!Compressed)
+ {
+ return;
+ }
+ CompositeBuffer BlockPayload = Compressed.DecompressToComposite();
+ if (!BlockPayload)
+ {
+ return;
+ }
+
+ // Walk every chunk in the block; store only even positions, still in compressed form.
+ uint32_t ChunkIndex = 0;
+ uint64_t HeaderSize = 0;
+ IterateChunkBlock(
+ BlockPayload.Flatten(),
+ [&TargetCidStore, &ChunkIndex](CompressedBuffer&& Chunk, const IoHash& AttachmentHash) {
+ if (ChunkIndex % 2 == 0)
+ {
+ IoBuffer ChunkData = Chunk.GetCompressed().Flatten().AsIoBuffer();
+ TargetCidStore.AddChunk(ChunkData, AttachmentHash);
+ }
+ ++ChunkIndex;
+ },
+ HeaderSize);
+}
+
+TEST_CASE("project.store.import.context_settings")
+{
+ using namespace std::literals;
+ using namespace projectstore_testutils;
+
+ ScopedTemporaryDirectory TempDir;
+ ScopedTemporaryDirectory ExportDir;
+
+ std::filesystem::path RootDir = TempDir.Path() / "root";
+ std::filesystem::path EngineRootDir = TempDir.Path() / "engine";
+ std::filesystem::path ProjectRootDir = TempDir.Path() / "game";
+ std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject";
+
+ // Export-side CAS and project store: used only by SetupExportStore to build the remote store
+ // payload. Kept separate from the import side so the two CAS instances are disjoint.
+ GcManager ExportGc;
+ CidStore ExportCidStore(ExportGc);
+ CidStoreConfiguration ExportCidConfig = {.RootDirectory = TempDir.Path() / "export_cas",
+ .TinyValueThreshold = 1024,
+ .HugeValueThreshold = 4096};
+ ExportCidStore.Initialize(ExportCidConfig);
+
+ std::filesystem::path ExportBasePath = TempDir.Path() / "export_projectstore";
+ ProjectStore ExportProjectStore(ExportCidStore, ExportBasePath, ExportGc, ProjectStore::Configuration{});
+ Ref<ProjectStore::Project> ExportProject(ExportProjectStore.NewProject(ExportBasePath / "proj1"sv,
+ "proj1"sv,
+ RootDir.string(),
+ EngineRootDir.string(),
+ ProjectRootDir.string(),
+ ProjectFilePath.string()));
+
+ uint32_t NetworkWorkerCount = Max(GetHardwareConcurrency() / 4u, 2u);
+ uint32_t WorkerCount = (NetworkWorkerCount < GetHardwareConcurrency()) ? Max(GetHardwareConcurrency() - NetworkWorkerCount, 4u) : 4u;
+ WorkerThreadPool WorkerPool(WorkerCount);
+ WorkerThreadPool NetworkPool(NetworkWorkerCount);
+
+ std::shared_ptr<RemoteProjectStore> RemoteStore;
+ RemoteProjectStore::Result ExportResult =
+ SetupExportStore(ExportCidStore, *ExportProject, NetworkPool, WorkerPool, ExportDir.Path(), RemoteStore);
+ REQUIRE(ExportResult.ErrorCode == 0);
+
+ // Import-side CAS and project store: starts empty, mirroring a fresh machine that has never
+ // downloaded the data. HasAttachment() therefore returns false for every chunk, so the import
+ // genuinely contacts the remote store without needing ForceDownload on the populate pass.
+ GcManager ImportGc;
+ CidStore ImportCidStore(ImportGc);
+ CidStoreConfiguration ImportCidConfig = {.RootDirectory = TempDir.Path() / "import_cas",
+ .TinyValueThreshold = 1024,
+ .HugeValueThreshold = 4096};
+ ImportCidStore.Initialize(ImportCidConfig);
+
+ std::filesystem::path ImportBasePath = TempDir.Path() / "import_projectstore";
+ ProjectStore ImportProjectStore(ImportCidStore, ImportBasePath, ImportGc, ProjectStore::Configuration{});
+ Ref<ProjectStore::Project> ImportProject(ImportProjectStore.NewProject(ImportBasePath / "proj1"sv,
+ "proj1"sv,
+ RootDir.string(),
+ EngineRootDir.string(),
+ ProjectRootDir.string(),
+ ProjectFilePath.string()));
+
+ const Oid CacheBuildId = Oid::NewOid();
+ BuildStorageCache::Statistics CacheStats;
+ std::unique_ptr<BuildStorageCache> Cache = CreateInMemoryBuildStorageCache(256u, CacheStats);
+ auto ResetCacheStats = [&]() {
+ CacheStats.TotalBytesRead = 0;
+ CacheStats.TotalBytesWritten = 0;
+ CacheStats.TotalRequestCount = 0;
+ CacheStats.TotalRequestTimeUs = 0;
+ CacheStats.TotalExecutionTimeUs = 0;
+ CacheStats.PeakSentBytes = 0;
+ CacheStats.PeakReceivedBytes = 0;
+ CacheStats.PeakBytesPerSec = 0;
+ CacheStats.PutBlobCount = 0;
+ CacheStats.PutBlobByteCount = 0;
+ };
+
+ int OpJobIndex = 0;
+
+ TestJobContext OpJobContext(OpJobIndex);
+
+ // Helper: run a LoadOplog against the import-side CAS/project with the given context knobs.
+ // Each call creates a fresh oplog so repeated calls within one SUBCASE don't short-circuit on
+ // already-present data.
+ auto DoImport = [&](BuildStorageCache* OptCache,
+ EPartialBlockRequestMode Mode,
+ double StoreLatency,
+ uint64_t StoreRanges,
+ double CacheLatency,
+ uint64_t CacheRanges,
+ bool PopulateCache,
+ bool ForceDownload) -> RemoteProjectStore::Result {
+ Ref<ProjectStore::Oplog> ImportOplog = ImportProject->NewOplog(fmt::format("import_{}", OpJobIndex++), {});
+ return LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore,
+ .RemoteStore = *RemoteStore,
+ .OptionalCache = OptCache,
+ .CacheBuildId = CacheBuildId,
+ .Oplog = *ImportOplog,
+ .NetworkWorkerPool = NetworkPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = ForceDownload,
+ .IgnoreMissingAttachments = false,
+ .CleanOplog = false,
+ .PartialBlockRequestMode = Mode,
+ .PopulateCache = PopulateCache,
+ .StoreLatencySec = StoreLatency,
+ .StoreMaxRangeCountPerRequest = StoreRanges,
+ .CacheLatencySec = CacheLatency,
+ .CacheMaxRangeCountPerRequest = CacheRanges,
+ .OptionalJobContext = &OpJobContext});
+ };
+
+ // Shorthand: Mode=All, low latency, 128 ranges for both store and cache.
+ auto ImportAll = [&](BuildStorageCache* OptCache, bool Populate, bool Force) {
+ return DoImport(OptCache, EPartialBlockRequestMode::All, 0.001, 128u, 0.001, 128u, Populate, Force);
+ };
+
+ SUBCASE("mode_off_no_cache")
+ {
+ // Baseline: no partial block requests, no cache.
+ RemoteProjectStore::Result R =
+ DoImport(nullptr, EPartialBlockRequestMode::Off, -1.0, (uint64_t)-1, -1.0, (uint64_t)-1, false, false);
+ CHECK(R.ErrorCode == 0);
+ }
+
+ SUBCASE("mode_all_multirange_cloud_no_cache")
+ {
+ // StoreMaxRangeCountPerRequest > 1 → MultiRange cloud path.
+ RemoteProjectStore::Result R = DoImport(nullptr, EPartialBlockRequestMode::All, 0.001, 128u, -1.0, 0u, false, false);
+ CHECK(R.ErrorCode == 0);
+ }
+
+ SUBCASE("mode_all_singlerange_cloud_no_cache")
+ {
+ // StoreMaxRangeCountPerRequest == 1 → SingleRange cloud path.
+ RemoteProjectStore::Result R = DoImport(nullptr, EPartialBlockRequestMode::All, 0.001, 1u, -1.0, 0u, false, false);
+ CHECK(R.ErrorCode == 0);
+ }
+
+ SUBCASE("mode_mixed_high_latency_no_cache")
+ {
+ // High store latency encourages range merging; Mixed uses SingleRange for cloud, Off for cache.
+ RemoteProjectStore::Result R = DoImport(nullptr, EPartialBlockRequestMode::Mixed, 0.1, 128u, -1.0, 0u, false, false);
+ CHECK(R.ErrorCode == 0);
+ }
+
+ SUBCASE("cache_populate_and_hit")
+ {
+ // First import: ImportCidStore is empty so all blocks are downloaded from the remote store
+ // and written to the cache.
+ RemoteProjectStore::Result PopulateResult = ImportAll(Cache.get(), /*PopulateCache=*/true, /*Force=*/false);
+ CHECK(PopulateResult.ErrorCode == 0);
+ CHECK(CacheStats.PutBlobCount > 0);
+
+ // Re-import with ForceDownload=true: all chunks are now in ImportCidStore but Force overrides
+ // HasAttachment() so the download logic re-runs and serves blocks from the cache instead of
+ // the remote store.
+ ResetCacheStats();
+ RemoteProjectStore::Result HitResult = ImportAll(Cache.get(), /*PopulateCache=*/false, /*Force=*/true);
+ CHECK(HitResult.ErrorCode == 0);
+ CHECK(CacheStats.PutBlobCount == 0);
+ // TotalRequestCount covers both full-blob cache hits and partial-range cache hits.
+ CHECK(CacheStats.TotalRequestCount > 0);
+ }
+
+ SUBCASE("cache_no_populate_flag")
+ {
+ // Cache is provided but PopulateCache=false: blocks are downloaded to ImportCidStore but
+ // nothing should be written to the cache.
+ RemoteProjectStore::Result R = ImportAll(Cache.get(), /*PopulateCache=*/false, /*Force=*/false);
+ CHECK(R.ErrorCode == 0);
+ CHECK(CacheStats.PutBlobCount == 0);
+ }
+
+ SUBCASE("mode_zencacheonly_cache_multirange")
+ {
+ // Pre-populate the cache via a plain import, then re-import with ZenCacheOnly +
+ // CacheMaxRangeCountPerRequest=128. With 100% of chunks needed, all blocks go to
+ // FullBlockIndexes and GetBuildBlob (full blob) is called from the cache.
+ // CacheMaxRangeCountPerRequest > 1 would route partial downloads through GetBuildBlobRanges
+ // if the analyser ever emits BlockRanges entries.
+ RemoteProjectStore::Result Populate = ImportAll(Cache.get(), /*PopulateCache=*/true, /*Force=*/false);
+ CHECK(Populate.ErrorCode == 0);
+ ResetCacheStats();
+
+ RemoteProjectStore::Result R = DoImport(Cache.get(), EPartialBlockRequestMode::ZenCacheOnly, 0.1, 128u, 0.001, 128u, false, true);
+ CHECK(R.ErrorCode == 0);
+ CHECK(CacheStats.TotalRequestCount > 0);
+ }
+
+ SUBCASE("mode_zencacheonly_cache_singlerange")
+ {
+ // Pre-populate the cache, then re-import with ZenCacheOnly + CacheMaxRangeCountPerRequest=1.
+ // With 100% of chunks needed the analyser sends all blocks to FullBlockIndexes (full-block
+ // download path), which calls GetBuildBlob with no range offset — a full-blob cache hit.
+ // The single-range vs multi-range distinction only matters for the partial-block (BlockRanges)
+ // path, which is not reached when all chunks are needed.
+ RemoteProjectStore::Result Populate = ImportAll(Cache.get(), /*PopulateCache=*/true, /*Force=*/false);
+ CHECK(Populate.ErrorCode == 0);
+ ResetCacheStats();
+
+ RemoteProjectStore::Result R = DoImport(Cache.get(), EPartialBlockRequestMode::ZenCacheOnly, 0.1, 128u, 0.001, 1u, false, true);
+ CHECK(R.ErrorCode == 0);
+ CHECK(CacheStats.TotalRequestCount > 0);
+ }
+
+ SUBCASE("mode_all_cache_and_cloud_multirange")
+ {
+ // Pre-populate cache; All mode uses multi-range for both the cache and cloud paths.
+ RemoteProjectStore::Result Populate = ImportAll(Cache.get(), /*PopulateCache=*/true, /*Force=*/false);
+ CHECK(Populate.ErrorCode == 0);
+ ResetCacheStats();
+
+ RemoteProjectStore::Result R = ImportAll(Cache.get(), /*PopulateCache=*/false, /*Force=*/true);
+ CHECK(R.ErrorCode == 0);
+ CHECK(CacheStats.TotalRequestCount > 0);
+ }
+
+ SUBCASE("partial_block_cloud_multirange")
+ {
+ // Export store with 6 × 512 KB chunks packed into one ~3 MB block.
+ ScopedTemporaryDirectory PartialExportDir;
+ std::shared_ptr<RemoteProjectStore> PartialRemoteStore;
+ RemoteProjectStore::Result ExportR =
+ SetupPartialBlockExportStore(NetworkPool, WorkerPool, PartialExportDir.Path(), PartialRemoteStore);
+ REQUIRE(ExportR.ErrorCode == 0);
+
+ // Seeding even-indexed chunks (0, 2, 4) leaves odd ones (1, 3, 5) absent in
+ // ImportCidStore. Three non-adjacent needed positions → three BlockRangeDescriptors.
+ IoHash BlockHash = FindBlockWithMultipleChunks(*PartialRemoteStore, 4u);
+ CHECK(BlockHash != IoHash::Zero);
+ SeedCidStoreWithAlternateChunks(ImportCidStore, *PartialRemoteStore, BlockHash);
+
+ // StoreMaxRangeCountPerRequest=128 → all three ranges sent in one LoadAttachmentRanges call.
+ Ref<ProjectStore::Oplog> PartialOplog = ImportProject->NewOplog(fmt::format("partial_cloud_multi_{}", OpJobIndex++), {});
+ RemoteProjectStore::Result R = LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore,
+ .RemoteStore = *PartialRemoteStore,
+ .OptionalCache = nullptr,
+ .CacheBuildId = CacheBuildId,
+ .Oplog = *PartialOplog,
+ .NetworkWorkerPool = NetworkPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = false,
+ .IgnoreMissingAttachments = false,
+ .CleanOplog = false,
+ .PartialBlockRequestMode = EPartialBlockRequestMode::All,
+ .PopulateCache = false,
+ .StoreLatencySec = 0.001,
+ .StoreMaxRangeCountPerRequest = 128u,
+ .CacheLatencySec = -1.0,
+ .CacheMaxRangeCountPerRequest = 0u,
+ .OptionalJobContext = &OpJobContext});
+ CHECK(R.ErrorCode == 0);
+ }
+
+ SUBCASE("partial_block_cloud_singlerange")
+ {
+ // Same block layout as partial_block_cloud_multirange but StoreMaxRangeCountPerRequest=1.
+ // DownloadPartialBlock issues one LoadAttachmentRanges call per range.
+ ScopedTemporaryDirectory PartialExportDir;
+ std::shared_ptr<RemoteProjectStore> PartialRemoteStore;
+ RemoteProjectStore::Result ExportR =
+ SetupPartialBlockExportStore(NetworkPool, WorkerPool, PartialExportDir.Path(), PartialRemoteStore);
+ REQUIRE(ExportR.ErrorCode == 0);
+
+ IoHash BlockHash = FindBlockWithMultipleChunks(*PartialRemoteStore, 4u);
+ CHECK(BlockHash != IoHash::Zero);
+ SeedCidStoreWithAlternateChunks(ImportCidStore, *PartialRemoteStore, BlockHash);
+
+ Ref<ProjectStore::Oplog> PartialOplog = ImportProject->NewOplog(fmt::format("partial_cloud_single_{}", OpJobIndex++), {});
+ RemoteProjectStore::Result R = LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore,
+ .RemoteStore = *PartialRemoteStore,
+ .OptionalCache = nullptr,
+ .CacheBuildId = CacheBuildId,
+ .Oplog = *PartialOplog,
+ .NetworkWorkerPool = NetworkPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = false,
+ .IgnoreMissingAttachments = false,
+ .CleanOplog = false,
+ .PartialBlockRequestMode = EPartialBlockRequestMode::All,
+ .PopulateCache = false,
+ .StoreLatencySec = 0.001,
+ .StoreMaxRangeCountPerRequest = 1u,
+ .CacheLatencySec = -1.0,
+ .CacheMaxRangeCountPerRequest = 0u,
+ .OptionalJobContext = &OpJobContext});
+ CHECK(R.ErrorCode == 0);
+ }
+
+ SUBCASE("partial_block_cache_multirange")
+ {
+ ScopedTemporaryDirectory PartialExportDir;
+ std::shared_ptr<RemoteProjectStore> PartialRemoteStore;
+ RemoteProjectStore::Result ExportR =
+ SetupPartialBlockExportStore(NetworkPool, WorkerPool, PartialExportDir.Path(), PartialRemoteStore);
+ REQUIRE(ExportR.ErrorCode == 0);
+
+ IoHash BlockHash = FindBlockWithMultipleChunks(*PartialRemoteStore, 4u);
+ CHECK(BlockHash != IoHash::Zero);
+
+ // Phase 1: ImportCidStore starts empty → full block download from remote → PutBuildBlob
+ // populates the cache.
+ {
+ Ref<ProjectStore::Oplog> Phase1Oplog = ImportProject->NewOplog(fmt::format("partial_cache_multi_p1_{}", OpJobIndex++), {});
+ RemoteProjectStore::Result Phase1R = LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore,
+ .RemoteStore = *PartialRemoteStore,
+ .OptionalCache = Cache.get(),
+ .CacheBuildId = CacheBuildId,
+ .Oplog = *Phase1Oplog,
+ .NetworkWorkerPool = NetworkPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = false,
+ .IgnoreMissingAttachments = false,
+ .CleanOplog = false,
+ .PartialBlockRequestMode = EPartialBlockRequestMode::All,
+ .PopulateCache = true,
+ .StoreLatencySec = 0.001,
+ .StoreMaxRangeCountPerRequest = 128u,
+ .CacheLatencySec = 0.001,
+ .CacheMaxRangeCountPerRequest = 128u,
+ .OptionalJobContext = &OpJobContext});
+ CHECK(Phase1R.ErrorCode == 0);
+ CHECK(CacheStats.PutBlobCount > 0);
+ }
+ ResetCacheStats();
+
+ // Phase 2: fresh CidStore with only even-indexed chunks seeded.
+ // HasAttachment returns false for odd chunks (1, 3, 5) → three BlockRangeDescriptors.
+ // Block is in cache from Phase 1 → cache partial path.
+ // CacheMaxRangeCountPerRequest=128 → SubRangeCount=3 > 1 → GetBuildBlobRanges.
+ GcManager Phase2Gc;
+ CidStore Phase2CidStore(Phase2Gc);
+ CidStoreConfiguration Phase2CidConfig = {.RootDirectory = TempDir.Path() / "partial_cas",
+ .TinyValueThreshold = 1024,
+ .HugeValueThreshold = 4096};
+ Phase2CidStore.Initialize(Phase2CidConfig);
+ SeedCidStoreWithAlternateChunks(Phase2CidStore, *PartialRemoteStore, BlockHash);
+
+ Ref<ProjectStore::Oplog> Phase2Oplog = ImportProject->NewOplog(fmt::format("partial_cache_multi_p2_{}", OpJobIndex++), {});
+ RemoteProjectStore::Result Phase2R = LoadOplog(LoadOplogContext{.ChunkStore = Phase2CidStore,
+ .RemoteStore = *PartialRemoteStore,
+ .OptionalCache = Cache.get(),
+ .CacheBuildId = CacheBuildId,
+ .Oplog = *Phase2Oplog,
+ .NetworkWorkerPool = NetworkPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = false,
+ .IgnoreMissingAttachments = false,
+ .CleanOplog = false,
+ .PartialBlockRequestMode = EPartialBlockRequestMode::ZenCacheOnly,
+ .PopulateCache = false,
+ .StoreLatencySec = 0.001,
+ .StoreMaxRangeCountPerRequest = 128u,
+ .CacheLatencySec = 0.001,
+ .CacheMaxRangeCountPerRequest = 128u,
+ .OptionalJobContext = &OpJobContext});
+ CHECK(Phase2R.ErrorCode == 0);
+ CHECK(CacheStats.TotalRequestCount > 0);
+ }
+
+ SUBCASE("partial_block_cache_singlerange")
+ {
+ ScopedTemporaryDirectory PartialExportDir;
+ std::shared_ptr<RemoteProjectStore> PartialRemoteStore;
+ RemoteProjectStore::Result ExportR =
+ SetupPartialBlockExportStore(NetworkPool, WorkerPool, PartialExportDir.Path(), PartialRemoteStore);
+ REQUIRE(ExportR.ErrorCode == 0);
+
+ IoHash BlockHash = FindBlockWithMultipleChunks(*PartialRemoteStore, 4u);
+ CHECK(BlockHash != IoHash::Zero);
+
+ // Phase 1: full block download from remote into cache.
+ {
+ Ref<ProjectStore::Oplog> Phase1Oplog = ImportProject->NewOplog(fmt::format("partial_cache_single_p1_{}", OpJobIndex++), {});
+ RemoteProjectStore::Result Phase1R = LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore,
+ .RemoteStore = *PartialRemoteStore,
+ .OptionalCache = Cache.get(),
+ .CacheBuildId = CacheBuildId,
+ .Oplog = *Phase1Oplog,
+ .NetworkWorkerPool = NetworkPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = false,
+ .IgnoreMissingAttachments = false,
+ .CleanOplog = false,
+ .PartialBlockRequestMode = EPartialBlockRequestMode::All,
+ .PopulateCache = true,
+ .StoreLatencySec = 0.001,
+ .StoreMaxRangeCountPerRequest = 128u,
+ .CacheLatencySec = 0.001,
+ .CacheMaxRangeCountPerRequest = 128u,
+ .OptionalJobContext = &OpJobContext});
+ CHECK(Phase1R.ErrorCode == 0);
+ CHECK(CacheStats.PutBlobCount > 0);
+ }
+ ResetCacheStats();
+
+ // Phase 2: fresh CidStore with only even-indexed chunks seeded.
+ // CacheMaxRangeCountPerRequest=1 → SubRangeCount=Min(3,1)=1 → GetBuildBlob with range
+ // offset (single-range legacy cache path), called once per needed chunk range.
+ GcManager Phase2Gc;
+ CidStore Phase2CidStore(Phase2Gc);
+ CidStoreConfiguration Phase2CidConfig = {.RootDirectory = TempDir.Path() / "partial_cas_single",
+ .TinyValueThreshold = 1024,
+ .HugeValueThreshold = 4096};
+ Phase2CidStore.Initialize(Phase2CidConfig);
+ SeedCidStoreWithAlternateChunks(Phase2CidStore, *PartialRemoteStore, BlockHash);
+
+ Ref<ProjectStore::Oplog> Phase2Oplog = ImportProject->NewOplog(fmt::format("partial_cache_single_p2_{}", OpJobIndex++), {});
+ RemoteProjectStore::Result Phase2R = LoadOplog(LoadOplogContext{.ChunkStore = Phase2CidStore,
+ .RemoteStore = *PartialRemoteStore,
+ .OptionalCache = Cache.get(),
+ .CacheBuildId = CacheBuildId,
+ .Oplog = *Phase2Oplog,
+ .NetworkWorkerPool = NetworkPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = false,
+ .IgnoreMissingAttachments = false,
+ .CleanOplog = false,
+ .PartialBlockRequestMode = EPartialBlockRequestMode::ZenCacheOnly,
+ .PopulateCache = false,
+ .StoreLatencySec = 0.001,
+ .StoreMaxRangeCountPerRequest = 128u,
+ .CacheLatencySec = 0.001,
+ .CacheMaxRangeCountPerRequest = 1u,
+ .OptionalJobContext = &OpJobContext});
+ CHECK(Phase2R.ErrorCode == 0);
+ CHECK(CacheStats.TotalRequestCount > 0);
+ }
+}
+
+TEST_SUITE_END();
+
#endif // ZEN_WITH_TESTS
void
diff --git a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp
index ab82edbef..115d6438d 100644
--- a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp
+++ b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp
@@ -159,7 +159,8 @@ public:
virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override
{
- std::string LoadRequest = fmt::format("/{}/oplog/{}/rpc"sv, m_Project, m_Oplog);
+ LoadAttachmentsResult Result;
+ std::string LoadRequest = fmt::format("/{}/oplog/{}/rpc"sv, m_Project, m_Oplog);
CbObject Request;
{
@@ -187,7 +188,7 @@ public:
HttpClient::Response Response = m_Client.Post(LoadRequest, Request, HttpClient::Accept(ZenContentType::kCbPackage));
AddStats(Response);
- LoadAttachmentsResult Result = LoadAttachmentsResult{ConvertResult(Response)};
+ Result = LoadAttachmentsResult{ConvertResult(Response)};
if (Result.ErrorCode)
{
Result.Reason = fmt::format("Failed fetching {} oplog attachments from {}/{}/{}. Reason: '{}'",
@@ -249,20 +250,49 @@ public:
return GetKnownBlocksResult{{.ErrorCode = static_cast<int>(HttpResponseCode::NoContent)}};
}
+ virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span<const IoHash> BlockHashes,
+ BuildStorageCache* OptionalCache,
+ const Oid& CacheBuildId) override
+ {
+ ZEN_UNUSED(BlockHashes, OptionalCache, CacheBuildId);
+ return GetBlockDescriptionsResult{Result{.ErrorCode = int(HttpResponseCode::NotFound)}};
+ }
+
virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override
{
+ LoadAttachmentResult Result;
std::string LoadRequest = fmt::format("/{}/oplog/{}/{}"sv, m_Project, m_Oplog, RawHash);
HttpClient::Response Response =
m_Client.Download(LoadRequest, m_TempFilePath, HttpClient::Accept(ZenContentType::kCompressedBinary));
AddStats(Response);
- LoadAttachmentResult Result = LoadAttachmentResult{ConvertResult(Response)};
- if (!Result.ErrorCode)
+ Result = LoadAttachmentResult{ConvertResult(Response)};
+ if (Result.ErrorCode)
{
- Result.Bytes = Response.ResponsePayload;
- Result.Bytes.MakeOwned();
+ Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}/{}. Reason: '{}'",
+ m_ProjectStoreUrl,
+ m_Project,
+ m_Oplog,
+ RawHash,
+ Result.Reason);
}
- if (!Result.ErrorCode)
+ Result.Bytes = Response.ResponsePayload;
+ Result.Bytes.MakeOwned();
+ return Result;
+ }
+
+ virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash,
+ std::span<const std::pair<uint64_t, uint64_t>> Ranges) override
+ {
+ ZEN_ASSERT(!Ranges.empty());
+ LoadAttachmentRangesResult Result;
+ std::string LoadRequest = fmt::format("/{}/oplog/{}/{}"sv, m_Project, m_Oplog, RawHash);
+ HttpClient::Response Response =
+ m_Client.Download(LoadRequest, m_TempFilePath, HttpClient::Accept(ZenContentType::kCompressedBinary));
+ AddStats(Response);
+
+ Result = LoadAttachmentRangesResult{ConvertResult(Response)};
+ if (Result.ErrorCode)
{
Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}/{}. Reason: '{}'",
m_ProjectStoreUrl,
@@ -271,11 +301,13 @@ public:
RawHash,
Result.Reason);
}
+ else
+ {
+ Result.Ranges = std::vector<std::pair<uint64_t, uint64_t>>(Ranges.begin(), Ranges.end());
+ }
return Result;
}
- virtual void Flush() override {}
-
private:
void AddStats(const HttpClient::Response& Result)
{
diff --git a/src/zenserver-test/buildstore-tests.cpp b/src/zenserver-test/buildstore-tests.cpp
index 02b308485..cf9b10896 100644
--- a/src/zenserver-test/buildstore-tests.cpp
+++ b/src/zenserver-test/buildstore-tests.cpp
@@ -27,6 +27,8 @@ namespace zen::tests {
using namespace std::literals;
+TEST_SUITE_BEGIN("server.buildstore");
+
TEST_CASE("buildstore.blobs")
{
std::filesystem::path SystemRootPath = TestEnv.CreateNewTestDir();
@@ -36,7 +38,8 @@ TEST_CASE("buildstore.blobs")
std::string_view Bucket = "bkt"sv;
Oid BuildId = Oid::NewOid();
- std::vector<IoHash> CompressedBlobsHashes;
+ std::vector<IoHash> CompressedBlobsHashes;
+ std::vector<uint64_t> CompressedBlobsSizes;
{
ZenServerInstance Instance(TestEnv);
@@ -51,6 +54,7 @@ TEST_CASE("buildstore.blobs")
IoBuffer Blob = CreateSemiRandomBlob(4711 + I * 7);
CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(std::move(Blob)));
CompressedBlobsHashes.push_back(CompressedBlob.DecodeRawHash());
+ CompressedBlobsSizes.push_back(CompressedBlob.GetCompressedSize());
IoBuffer Payload = std::move(CompressedBlob).GetCompressed().Flatten().AsIoBuffer();
Payload.SetContentType(ZenContentType::kCompressedBinary);
@@ -107,6 +111,7 @@ TEST_CASE("buildstore.blobs")
IoBuffer Blob = CreateSemiRandomBlob(5713 + I * 7);
CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(std::move(Blob)));
CompressedBlobsHashes.push_back(CompressedBlob.DecodeRawHash());
+ CompressedBlobsSizes.push_back(CompressedBlob.GetCompressedSize());
IoBuffer Payload = std::move(CompressedBlob).GetCompressed().Flatten().AsIoBuffer();
Payload.SetContentType(ZenContentType::kCompressedBinary);
@@ -141,6 +146,201 @@ TEST_CASE("buildstore.blobs")
CHECK(IoHash::HashBuffer(Decompressed) == RawHash);
}
}
+
+ {
+ // Single-range Get
+
+ ZenServerInstance Instance(TestEnv);
+
+ const uint16_t PortNumber =
+ Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath));
+ CHECK(PortNumber != 0);
+
+ HttpClient Client(Instance.GetBaseUri() + "/builds/");
+
+ {
+ const IoHash& RawHash = CompressedBlobsHashes.front();
+ uint64_t BlobSize = CompressedBlobsSizes.front();
+
+ std::vector<std::pair<uint64_t, uint64_t>> Ranges = {{BlobSize / 16 * 1, BlobSize / 2}};
+
+ uint64_t RangeSizeSum = Ranges.front().second;
+
+ HttpClient::KeyValueMap Headers;
+
+ Headers.Entries.insert(
+ {"Range", fmt::format("bytes={}-{}", Ranges.front().first, Ranges.front().first + Ranges.front().second - 1)});
+
+ HttpClient::Response Result = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), Headers);
+ REQUIRE(Result);
+ IoBuffer Payload = Result.ResponsePayload;
+ CHECK_EQ(RangeSizeSum, Payload.GetSize());
+
+ HttpClient::Response FullBlobResult = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash),
+ HttpClient::Accept(ZenContentType::kCompressedBinary));
+ REQUIRE(FullBlobResult);
+ MemoryView ActualRange = FullBlobResult.ResponsePayload.GetView().Mid(Ranges.front().first, Ranges.front().second);
+ MemoryView RangeView = Payload.GetView();
+ CHECK(ActualRange.EqualBytes(RangeView));
+ }
+ }
+
+ {
+ // Single-range Post
+
+ ZenServerInstance Instance(TestEnv);
+
+ const uint16_t PortNumber =
+ Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath));
+ CHECK(PortNumber != 0);
+
+ HttpClient Client(Instance.GetBaseUri() + "/builds/");
+
+ {
+ uint64_t RangeSizeSum = 0;
+
+ const IoHash& RawHash = CompressedBlobsHashes.front();
+ uint64_t BlobSize = CompressedBlobsSizes.front();
+
+ std::vector<std::pair<uint64_t, uint64_t>> Ranges = {{BlobSize / 16 * 1, BlobSize / 2}};
+
+ CbObjectWriter Writer;
+ Writer.BeginArray("ranges"sv);
+ {
+ for (const std::pair<uint64_t, uint64_t>& Range : Ranges)
+ {
+ Writer.BeginObject();
+ {
+ Writer.AddInteger("offset"sv, Range.first);
+ Writer.AddInteger("length"sv, Range.second);
+ RangeSizeSum += Range.second;
+ }
+ Writer.EndObject();
+ }
+ }
+ Writer.EndArray(); // ranges
+
+ HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash),
+ Writer.Save(),
+ HttpClient::Accept(ZenContentType::kCbPackage));
+ REQUIRE(Result);
+ IoBuffer Payload = Result.ResponsePayload;
+ REQUIRE(Payload.GetContentType() == ZenContentType::kCbPackage);
+
+ CbPackage ResponsePackage = ParsePackageMessage(Payload);
+ CbObjectView ResponseObject = ResponsePackage.GetObject();
+
+ CbArrayView RangeArray = ResponseObject["ranges"sv].AsArrayView();
+ CHECK_EQ(RangeArray.Num(), Ranges.size());
+ size_t RangeOffset = 0;
+ for (CbFieldView View : RangeArray)
+ {
+ CbObjectView Range = View.AsObjectView();
+ CHECK_EQ(Range["offset"sv].AsUInt64(), Ranges[RangeOffset].first);
+ CHECK_EQ(Range["length"sv].AsUInt64(), Ranges[RangeOffset].second);
+ RangeOffset++;
+ }
+
+ const CbAttachment* DataAttachment = ResponsePackage.FindAttachment(RawHash);
+ REQUIRE(DataAttachment);
+ SharedBuffer PayloadRanges = DataAttachment->AsBinary();
+ CHECK_EQ(RangeSizeSum, PayloadRanges.GetSize());
+
+ HttpClient::Response FullBlobResult = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash),
+ HttpClient::Accept(ZenContentType::kCompressedBinary));
+ REQUIRE(FullBlobResult);
+
+ uint64_t Offset = 0;
+ for (const std::pair<uint64_t, uint64_t>& Range : Ranges)
+ {
+ MemoryView ActualRange = FullBlobResult.ResponsePayload.GetView().Mid(Range.first, Range.second);
+ MemoryView RangeView = PayloadRanges.GetView().Mid(Offset, Range.second);
+ CHECK(ActualRange.EqualBytes(RangeView));
+ Offset += Range.second;
+ }
+ }
+ }
+
+ {
+ // Multi-range
+
+ ZenServerInstance Instance(TestEnv);
+
+ const uint16_t PortNumber =
+ Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath));
+ CHECK(PortNumber != 0);
+
+ HttpClient Client(Instance.GetBaseUri() + "/builds/");
+
+ {
+ uint64_t RangeSizeSum = 0;
+
+ const IoHash& RawHash = CompressedBlobsHashes.front();
+ uint64_t BlobSize = CompressedBlobsSizes.front();
+
+ std::vector<std::pair<uint64_t, uint64_t>> Ranges = {
+ {BlobSize / 16 * 1, BlobSize / 20},
+ {BlobSize / 16 * 3, BlobSize / 32},
+ {BlobSize / 16 * 5, BlobSize / 16},
+ {BlobSize - BlobSize / 16, BlobSize / 16 - 1},
+ };
+
+ CbObjectWriter Writer;
+ Writer.BeginArray("ranges"sv);
+ {
+ for (const std::pair<uint64_t, uint64_t>& Range : Ranges)
+ {
+ Writer.BeginObject();
+ {
+ Writer.AddInteger("offset"sv, Range.first);
+ Writer.AddInteger("length"sv, Range.second);
+ RangeSizeSum += Range.second;
+ }
+ Writer.EndObject();
+ }
+ }
+ Writer.EndArray(); // ranges
+
+ HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash),
+ Writer.Save(),
+ HttpClient::Accept(ZenContentType::kCbPackage));
+ REQUIRE(Result);
+ IoBuffer Payload = Result.ResponsePayload;
+ REQUIRE(Payload.GetContentType() == ZenContentType::kCbPackage);
+
+ CbPackage ResponsePackage = ParsePackageMessage(Payload);
+ CbObjectView ResponseObject = ResponsePackage.GetObject();
+
+ CbArrayView RangeArray = ResponseObject["ranges"sv].AsArrayView();
+ CHECK_EQ(RangeArray.Num(), Ranges.size());
+ size_t RangeOffset = 0;
+ for (CbFieldView View : RangeArray)
+ {
+ CbObjectView Range = View.AsObjectView();
+ CHECK_EQ(Range["offset"sv].AsUInt64(), Ranges[RangeOffset].first);
+ CHECK_EQ(Range["length"sv].AsUInt64(), Ranges[RangeOffset].second);
+ RangeOffset++;
+ }
+
+ const CbAttachment* DataAttachment = ResponsePackage.FindAttachment(RawHash);
+ REQUIRE(DataAttachment);
+ SharedBuffer PayloadRanges = DataAttachment->AsBinary();
+ CHECK_EQ(RangeSizeSum, PayloadRanges.GetSize());
+
+ HttpClient::Response FullBlobResult = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash),
+ HttpClient::Accept(ZenContentType::kCompressedBinary));
+ REQUIRE(FullBlobResult);
+
+ uint64_t Offset = 0;
+ for (const std::pair<uint64_t, uint64_t>& Range : Ranges)
+ {
+ MemoryView ActualRange = FullBlobResult.ResponsePayload.GetView().Mid(Range.first, Range.second);
+ MemoryView RangeView = PayloadRanges.GetView().Mid(Offset, Range.second);
+ CHECK(ActualRange.EqualBytes(RangeView));
+ Offset += Range.second;
+ }
+ }
+ }
}
namespace {
@@ -191,7 +391,7 @@ TEST_CASE("buildstore.metadata")
HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/getBlobMetadata", Namespace, Bucket, BuildId),
Payload,
HttpClient::Accept(ZenContentType::kCbObject));
- CHECK(Result);
+ REQUIRE(Result);
std::vector<CbObject> ResultMetadatas;
@@ -372,7 +572,7 @@ TEST_CASE("buildstore.cache")
{
std::vector<BuildStorageCache::BlobExistsResult> Exists = Cache->BlobsExists(BuildId, BlobHashes);
- CHECK(Exists.size() == BlobHashes.size());
+ REQUIRE(Exists.size() == BlobHashes.size());
for (size_t I = 0; I < BlobCount; I++)
{
CHECK(Exists[I].HasBody);
@@ -411,7 +611,7 @@ TEST_CASE("buildstore.cache")
{
std::vector<BuildStorageCache::BlobExistsResult> Exists = Cache->BlobsExists(BuildId, BlobHashes);
- CHECK(Exists.size() == BlobHashes.size());
+ REQUIRE(Exists.size() == BlobHashes.size());
for (size_t I = 0; I < BlobCount; I++)
{
CHECK(Exists[I].HasBody);
@@ -419,7 +619,7 @@ TEST_CASE("buildstore.cache")
}
std::vector<CbObject> FetchedMetadatas = Cache->GetBlobMetadatas(BuildId, BlobHashes);
- CHECK_EQ(BlobCount, FetchedMetadatas.size());
+ REQUIRE_EQ(BlobCount, FetchedMetadatas.size());
for (size_t I = 0; I < BlobCount; I++)
{
@@ -440,7 +640,7 @@ TEST_CASE("buildstore.cache")
{
std::vector<BuildStorageCache::BlobExistsResult> Exists = Cache->BlobsExists(BuildId, BlobHashes);
- CHECK(Exists.size() == BlobHashes.size());
+ REQUIRE(Exists.size() == BlobHashes.size());
for (size_t I = 0; I < BlobCount * 2; I++)
{
CHECK(Exists[I].HasBody);
@@ -451,7 +651,7 @@ TEST_CASE("buildstore.cache")
CHECK_EQ(BlobCount, MetaDatas.size());
std::vector<CbObject> FetchedMetadatas = Cache->GetBlobMetadatas(BuildId, BlobHashes);
- CHECK_EQ(BlobCount, FetchedMetadatas.size());
+ REQUIRE_EQ(BlobCount, FetchedMetadatas.size());
for (size_t I = 0; I < BlobCount; I++)
{
@@ -474,7 +674,7 @@ TEST_CASE("buildstore.cache")
CreateZenBuildStorageCache(Client, Stats, Namespace, Bucket, TempDir, GetTinyWorkerPool(EWorkloadType::Background)));
std::vector<BuildStorageCache::BlobExistsResult> Exists = Cache->BlobsExists(BuildId, BlobHashes);
- CHECK(Exists.size() == BlobHashes.size());
+ REQUIRE(Exists.size() == BlobHashes.size());
for (size_t I = 0; I < BlobCount * 2; I++)
{
CHECK(Exists[I].HasBody);
@@ -493,7 +693,7 @@ TEST_CASE("buildstore.cache")
CHECK_EQ(BlobCount, MetaDatas.size());
std::vector<CbObject> FetchedMetadatas = Cache->GetBlobMetadatas(BuildId, BlobHashes);
- CHECK_EQ(BlobCount, FetchedMetadatas.size());
+ REQUIRE_EQ(BlobCount, FetchedMetadatas.size());
for (size_t I = 0; I < BlobCount; I++)
{
@@ -502,5 +702,7 @@ TEST_CASE("buildstore.cache")
}
}
+TEST_SUITE_END();
+
} // namespace zen::tests
#endif
diff --git a/src/zenserver-test/cache-tests.cpp b/src/zenserver-test/cache-tests.cpp
index 0272d3797..334dd04ab 100644
--- a/src/zenserver-test/cache-tests.cpp
+++ b/src/zenserver-test/cache-tests.cpp
@@ -23,6 +23,8 @@
namespace zen::tests {
+TEST_SUITE_BEGIN("server.cache");
+
TEST_CASE("zcache.basic")
{
using namespace std::literals;
@@ -145,7 +147,7 @@ TEST_CASE("zcache.cbpackage")
for (const zen::CbAttachment& LhsAttachment : LhsAttachments)
{
const zen::CbAttachment* RhsAttachment = Rhs.FindAttachment(LhsAttachment.GetHash());
- CHECK(RhsAttachment);
+ REQUIRE(RhsAttachment);
zen::SharedBuffer LhsBuffer = LhsAttachment.AsCompressedBinary().Decompress();
CHECK(!LhsBuffer.IsNull());
@@ -1373,14 +1375,8 @@ TEST_CASE("zcache.rpc")
}
}
-TEST_CASE("zcache.failing.upstream")
+TEST_CASE("zcache.failing.upstream" * doctest::skip())
{
- // This is an exploratory test that takes a long time to run, so lets skip it by default
- if (true)
- {
- return;
- }
-
using namespace std::literals;
using namespace utils;
@@ -2669,6 +2665,8 @@ TEST_CASE("zcache.batchoperations")
}
}
+TEST_SUITE_END();
+
} // namespace zen::tests
#endif
diff --git a/src/zenserver-test/cacherequests.cpp b/src/zenserver-test/cacherequests.cpp
index 46339aebb..f5302a359 100644
--- a/src/zenserver-test/cacherequests.cpp
+++ b/src/zenserver-test/cacherequests.cpp
@@ -1037,6 +1037,8 @@ namespace zen { namespace cacherequests {
static CompressedBuffer MakeCompressedBuffer(size_t Size) { return CompressedBuffer::Compress(SharedBuffer(IoBuffer(Size))); };
+ TEST_SUITE_BEGIN("server.cacherequests");
+
TEST_CASE("cacherequests.put.cache.records")
{
PutCacheRecordsRequest EmptyRequest;
@@ -1458,5 +1460,7 @@ namespace zen { namespace cacherequests {
"!default!",
Invalid));
}
+
+ TEST_SUITE_END();
#endif
}} // namespace zen::cacherequests
diff --git a/src/zenserver-test/compute-tests.cpp b/src/zenserver-test/compute-tests.cpp
new file mode 100644
index 000000000..c90ac5d8b
--- /dev/null
+++ b/src/zenserver-test/compute-tests.cpp
@@ -0,0 +1,1700 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/zencore.h>
+
+#if ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES
+
+# include <zenbase/zenbase.h>
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/compress.h>
+# include <zencore/filesystem.h>
+# include <zencore/guid.h>
+# include <zencore/iobuffer.h>
+# include <zencore/iohash.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+# include <zencore/thread.h>
+# include <zencore/timer.h>
+# include <zenhttp/httpclient.h>
+# include <zenhttp/httpserver.h>
+# include <zencompute/computeservice.h>
+# include <zenstore/zenstore.h>
+# include <zenutil/zenserverprocess.h>
+
+# include "zenserver-test.h"
+
+# include <thread>
+
+namespace zen::tests::compute {
+
+using namespace std::literals;
+
+// BuildSystemVersion and function version GUIDs matching zentest-appstub
+// (the test worker executable these tests register and execute). These
+// strings are parsed with Guid::FromString below; presumably they must match
+// the values baked into zentest-appstub — confirm against the app stub source.
+static constexpr std::string_view kBuildSystemVersion = "17fe280d-ccd8-4be8-a9d1-89c944a70969";
+// Version GUID of the "Rot13" function exposed by the app stub.
+static constexpr std::string_view kRot13Version = "13131313-1313-1313-1313-131313131313";
+// Version GUID of the "Sleep" function exposed by the app stub.
+static constexpr std::string_view kSleepVersion = "88888888-8888-8888-8888-888888888888";
+
+// In-memory implementation of ChunkResolver for test use.
+// Stores compressed data keyed by decompressed content hash.
+class InMemoryChunkResolver : public ChunkResolver
+{
+public:
+    // Return the chunk registered under DecompressedId, or an empty buffer
+    // when the hash is unknown.
+    IoBuffer FindChunkByCid(const IoHash& DecompressedId) override
+    {
+        if (auto It = m_ChunksById.find(DecompressedId); It != m_ChunksById.end())
+        {
+            return It->second;
+        }
+        return {};
+    }
+
+    // Register a chunk under the hash of its decompressed content.
+    void AddChunk(const IoHash& DecompressedId, IoBuffer Data) { m_ChunksById[DecompressedId] = std::move(Data); }
+
+private:
+    std::unordered_map<IoHash, IoBuffer> m_ChunksById;
+};
+
+// Read, compress, and register zentest-appstub as a worker.
+// Returns the WorkerId (hash of the worker package object).
+static IoHash
+RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env)
+{
+    const std::filesystem::path StubPath = Env.ProgramBaseDir() / ("zentest-appstub" ZEN_EXE_SUFFIX_LITERAL);
+
+    FileContents StubContents = zen::ReadFile(StubPath);
+    REQUIRE_MESSAGE(!StubContents.ErrorCode, fmt::format("Failed to read '{}': {}", StubPath.string(), StubContents.ErrorCode.message()));
+
+    IoBuffer StubBytes = StubContents.Flatten();
+
+    CompressedBuffer StubCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(StubBytes.GetData(), StubBytes.Size()),
+                                                                 OodleCompressor::Selkie,
+                                                                 OodleCompressionLevel::HyperFast4);
+
+    const IoHash   StubRawHash = StubCompressed.DecodeRawHash();
+    const uint64_t StubRawSize = StubBytes.Size();
+
+    CbAttachment StubAttachment(std::move(StubCompressed), StubRawHash);
+
+    // Note: the worker id is the hash of the serialized object, so the field
+    // order written here is significant.
+    CbObjectWriter Writer;
+    Writer << "buildsystem_version"sv << Guid::FromString(kBuildSystemVersion);
+    Writer << "path"sv << "zentest-appstub"sv;
+
+    Writer.BeginArray("executables"sv);
+    Writer.BeginObject();
+    Writer << "name"sv << "zentest-appstub"sv;
+    Writer.AddAttachment("hash"sv, StubAttachment);
+    Writer << "size"sv << StubRawSize;
+    Writer.EndObject();
+    Writer.EndArray();
+
+    Writer.BeginArray("functions"sv);
+    Writer.BeginObject();
+    Writer << "name"sv << "Rot13"sv;
+    Writer << "version"sv << Guid::FromString(kRot13Version);
+    Writer.EndObject();
+    Writer.BeginObject();
+    Writer << "name"sv << "Sleep"sv;
+    Writer << "version"sv << Guid::FromString(kSleepVersion);
+    Writer.EndObject();
+    Writer.EndArray();
+
+    CbPackage WorkerPackage;
+    WorkerPackage.SetObject(Writer.Save());
+    WorkerPackage.AddAttachment(StubAttachment);
+
+    const IoHash WorkerId = WorkerPackage.GetObjectHash();
+
+    const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString());
+    HttpClient::Response RegisterResp = Client.Post(WorkerUrl, std::move(WorkerPackage));
+    REQUIRE_MESSAGE(RegisterResp,
+                    fmt::format("Worker registration failed: status={}, body={}", int(RegisterResp.StatusCode), RegisterResp.ToText()));
+
+    return WorkerId;
+}
+
+// Build a Rot13 action CbPackage for the given input string.
+static CbPackage
+BuildRot13ActionPackage(std::string_view Input)
+{
+    CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()),
+                                                             OodleCompressor::Selkie,
+                                                             OodleCompressionLevel::HyperFast4);
+
+    const IoHash   RawHash = Compressed.DecodeRawHash();
+    const uint64_t RawSize = Input.size();
+
+    CbAttachment Attachment(std::move(Compressed), RawHash);
+
+    // Describe the action: which function to run and where its input lives.
+    CbObjectWriter Writer;
+    Writer << "Function"sv << "Rot13"sv;
+    Writer << "FunctionVersion"sv << Guid::FromString(kRot13Version);
+    Writer << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion);
+    Writer.BeginObject("Inputs"sv);
+    Writer.BeginObject("Source"sv);
+    Writer.AddAttachment("RawHash"sv, Attachment);
+    Writer << "RawSize"sv << RawSize;
+    Writer.EndObject();
+    Writer.EndObject();
+
+    // Ship the compressed input alongside the action object.
+    CbPackage Package;
+    Package.SetObject(Writer.Save());
+    Package.AddAttachment(Attachment);
+
+    return Package;
+}
+
+// Build a Sleep action CbPackage. The worker sleeps for SleepTimeMs before returning its input.
+static CbPackage
+BuildSleepActionPackage(std::string_view Input, uint64_t SleepTimeMs)
+{
+    CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()),
+                                                             OodleCompressor::Selkie,
+                                                             OodleCompressionLevel::HyperFast4);
+
+    const IoHash   RawHash = Compressed.DecodeRawHash();
+    const uint64_t RawSize = Input.size();
+
+    CbAttachment Attachment(std::move(Compressed), RawHash);
+
+    // Describe the action; the sleep duration travels in the "Constants" object.
+    CbObjectWriter Writer;
+    Writer << "Function"sv << "Sleep"sv;
+    Writer << "FunctionVersion"sv << Guid::FromString(kSleepVersion);
+    Writer << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion);
+    Writer.BeginObject("Inputs"sv);
+    Writer.BeginObject("Source"sv);
+    Writer.AddAttachment("RawHash"sv, Attachment);
+    Writer << "RawSize"sv << RawSize;
+    Writer.EndObject();
+    Writer.EndObject();
+    Writer.BeginObject("Constants"sv);
+    Writer << "SleepTimeMs"sv << SleepTimeMs;
+    Writer.EndObject();
+
+    // Ship the compressed input alongside the action object.
+    CbPackage Package;
+    Package.SetObject(Writer.Save());
+    Package.AddAttachment(Attachment);
+
+    return Package;
+}
+
+// Build a Sleep action CbObject and populate the chunk resolver with the input attachment.
+static CbObject
+BuildSleepActionForSession(std::string_view Input, uint64_t SleepTimeMs, InMemoryChunkResolver& Resolver)
+{
+    CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()),
+                                                             OodleCompressor::Selkie,
+                                                             OodleCompressionLevel::HyperFast4);
+
+    const IoHash   RawHash = Compressed.DecodeRawHash();
+    const uint64_t RawSize = Input.size();
+
+    // Session-based execution pulls inputs through the resolver rather than a
+    // package attachment, so register the compressed bytes up front.
+    Resolver.AddChunk(RawHash, Compressed.GetCompressed().Flatten().AsIoBuffer());
+
+    CbAttachment Attachment(std::move(Compressed), RawHash);
+
+    CbObjectWriter Writer;
+    Writer << "Function"sv << "Sleep"sv;
+    Writer << "FunctionVersion"sv << Guid::FromString(kSleepVersion);
+    Writer << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion);
+    Writer.BeginObject("Inputs"sv);
+    Writer.BeginObject("Source"sv);
+    Writer.AddAttachment("RawHash"sv, Attachment);
+    Writer << "RawSize"sv << RawSize;
+    Writer.EndObject();
+    Writer.EndObject();
+    Writer.BeginObject("Constants"sv);
+    Writer << "SleepTimeMs"sv << SleepTimeMs;
+    Writer.EndObject();
+
+    return Writer.Save();
+}
+
+// Repeatedly GET ResultUrl until it returns 200 OK or TimeoutMs elapses.
+// Returns the last response received; on timeout that response carries a
+// non-OK status the caller can inspect.
+static HttpClient::Response
+PollForResult(HttpClient& Client, const std::string& ResultUrl, uint64_t TimeoutMs = 30'000)
+{
+    HttpClient::Response Resp;
+    Stopwatch            Timer;
+
+    for (;;)
+    {
+        Resp = Client.Get(ResultUrl);
+
+        if (Resp.StatusCode == HttpResponseCode::OK)
+        {
+            break;
+        }
+
+        // Check the deadline before sleeping so we neither waste a final
+        // 100 ms back-off nor return a response captured before the deadline.
+        if (Timer.GetElapsedTimeMs() >= TimeoutMs)
+        {
+            break;
+        }
+
+        Sleep(100);
+    }
+
+    return Resp;
+}
+
+// Poll the queue's /completed endpoint until the given LSN appears or
+// TimeoutMs elapses. Returns true when the LSN was observed, false on timeout.
+static bool
+PollForLsnInCompleted(HttpClient& Client, const std::string& CompletedUrl, int Lsn, uint64_t TimeoutMs = 30'000)
+{
+    Stopwatch Timer;
+
+    for (;;)
+    {
+        HttpClient::Response Resp = Client.Get(CompletedUrl);
+
+        // Failed polls (e.g. transient errors) are simply retried until timeout.
+        if (Resp)
+        {
+            for (auto& Item : Resp.AsObject()["completed"sv])
+            {
+                if (Item.AsInt32() == Lsn)
+                {
+                    return true;
+                }
+            }
+        }
+
+        // Check the deadline before sleeping so we don't waste a final
+        // 100 ms back-off after the last permissible poll.
+        if (Timer.GetElapsedTimeMs() >= TimeoutMs)
+        {
+            return false;
+        }
+
+        Sleep(100);
+    }
+}
+
+// Extract the single output value from a Rot13 result package and return its
+// decompressed payload as a string. Fails the enclosing test (via REQUIRE) if
+// the package lacks a well-formed output value.
+static std::string
+GetRot13Output(const CbPackage& ResultPackage)
+{
+    CbObject ResultObj = ResultPackage.GetObject();
+
+    IoHash OutputHash;
+    CbFieldView ValuesField = ResultObj["Values"sv];
+
+    // Take the first entry of the "Values" array; Rot13 produces one output.
+    if (CbFieldViewIterator It = begin(ValuesField); It.HasValue())
+    {
+        OutputHash = (*It).AsObjectView()["RawHash"sv].AsHash();
+    }
+
+    REQUIRE_MESSAGE(OutputHash != IoHash::Zero, "Expected non-zero output hash in result Values array");
+
+    const CbAttachment* OutputAttachment = ResultPackage.FindAttachment(OutputHash);
+    REQUIRE_MESSAGE(OutputAttachment != nullptr, "Output attachment not found in result package");
+
+    CompressedBuffer OutputCompressed = OutputAttachment->AsCompressedBinary();
+    SharedBuffer OutputData = OutputCompressed.Decompress();
+
+    // Guard against failed decompression: constructing std::string from a
+    // null pointer is undefined behavior.
+    REQUIRE_MESSAGE(!OutputData.IsNull(), "Failed to decompress output attachment");
+
+    return std::string(static_cast<const char*>(OutputData.GetData()), OutputData.GetSize());
+}
+
+// Mock orchestrator HTTP service that serves GET /orch/agents with a controllable response.
+class MockOrchestratorService : public HttpService
+{
+public:
+    MockOrchestratorService()
+    {
+        // Start with an empty worker list so early polls receive a valid document.
+        CbObjectWriter Writer;
+        Writer.BeginArray("workers"sv);
+        Writer.EndArray();
+        m_WorkerList = Writer.Save();
+    }
+
+    const char* BaseUri() const override { return "/orch/"; }
+
+    void HandleRequest(HttpServerRequest& Request) override
+    {
+        // Only GET /orch/agents is served; everything else is 404.
+        const bool IsAgentsQuery = (Request.RequestVerb() == HttpVerb::kGet) && (Request.RelativeUri() == "agents"sv);
+        if (!IsAgentsQuery)
+        {
+            Request.WriteResponse(HttpResponseCode::NotFound);
+            return;
+        }
+
+        RwLock::SharedLockScope Lock(m_Lock);
+        Request.WriteResponse(HttpResponseCode::OK, m_WorkerList);
+    }
+
+    // Replace the document returned from GET /orch/agents.
+    void SetWorkerList(CbObject WorkerList)
+    {
+        RwLock::ExclusiveLockScope Lock(m_Lock);
+        m_WorkerList = std::move(WorkerList);
+    }
+
+private:
+    RwLock   m_Lock;  // Guards m_WorkerList against concurrent reader/writer access
+    CbObject m_WorkerList;
+};
+
+// Manages in-process ASIO HTTP server lifecycle for mock orchestrator.
+// Construction registers the service, binds a fresh loopback port, and runs
+// the server on a dedicated thread; destruction shuts it down in order.
+struct MockOrchestratorFixture
+{
+    MockOrchestratorService Service;
+    ScopedTemporaryDirectory TmpDir;
+    Ref<HttpServer> Server;
+    std::thread ServerThread;
+    uint16_t Port = 0;   // Actual bound port; 0 means initialization failed
+
+    MockOrchestratorFixture()
+    {
+        HttpServerConfig Config;
+        Config.ServerClass = "asio";
+        Config.ForceLoopback = true;
+        Server = CreateHttpServer(Config);
+        // Service must be registered before Initialize/Run so requests route to it.
+        Server->RegisterService(Service);
+        Port = static_cast<uint16_t>(Server->Initialize(TestEnv.GetNewPortNumber(), TmpDir.Path()));
+        ZEN_ASSERT(Port != 0);
+        // Run(false) blocks, so the server gets its own thread.
+        ServerThread = std::thread([this]() { Server->Run(false); });
+    }
+
+    ~MockOrchestratorFixture()
+    {
+        // Shutdown order matters: ask the server loop to exit, join the thread
+        // running it, and only then close the server.
+        Server->RequestExit();
+        if (ServerThread.joinable())
+        {
+            ServerThread.join();
+        }
+        Server->Close();
+    }
+
+    // Base URL clients should use to reach the mock orchestrator.
+    std::string GetEndpoint() const { return fmt::format("http://localhost:{}", Port); }
+};
+
+// Build the CbObject response for /orch/agents matching the format UpdateCoordinatorState expects.
+static CbObject
+BuildAgentListResponse(std::initializer_list<std::pair<std::string_view, std::string_view>> Workers)
+{
+    CbObjectWriter Writer;
+    Writer.BeginArray("workers"sv);
+    for (const auto& [WorkerId, WorkerUri] : Workers)
+    {
+        // One object per worker: identity, endpoint, and liveness fields.
+        Writer.BeginObject();
+        Writer << "id"sv << WorkerId;
+        Writer << "uri"sv << WorkerUri;
+        Writer << "hostname"sv << "localhost"sv;
+        Writer << "reachable"sv << true;
+        Writer << "dt"sv << uint64_t(0);
+        Writer.EndObject();
+    }
+    Writer.EndArray();
+    return Writer.Save();
+}
+
+// Build the worker CbPackage for zentest-appstub AND populate the chunk resolver.
+// This is the same logic as RegisterWorker() but returns the package instead of POSTing it.
+static CbPackage
+BuildWorkerPackage(ZenServerEnvironment& Env, InMemoryChunkResolver& Resolver)
+{
+    const std::filesystem::path StubPath = Env.ProgramBaseDir() / ("zentest-appstub" ZEN_EXE_SUFFIX_LITERAL);
+
+    FileContents StubContents = zen::ReadFile(StubPath);
+    REQUIRE_MESSAGE(!StubContents.ErrorCode, fmt::format("Failed to read '{}': {}", StubPath.string(), StubContents.ErrorCode.message()));
+
+    IoBuffer StubBytes = StubContents.Flatten();
+
+    CompressedBuffer StubCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(StubBytes.GetData(), StubBytes.Size()),
+                                                                 OodleCompressor::Selkie,
+                                                                 OodleCompressionLevel::HyperFast4);
+
+    const IoHash   StubRawHash = StubCompressed.DecodeRawHash();
+    const uint64_t StubRawSize = StubBytes.Size();
+
+    // Store compressed data in chunk resolver for when the remote runner needs it
+    Resolver.AddChunk(StubRawHash, StubCompressed.GetCompressed().Flatten().AsIoBuffer());
+
+    CbAttachment StubAttachment(std::move(StubCompressed), StubRawHash);
+
+    // Note: field order is significant — the worker id is derived from the
+    // hash of this serialized object.
+    CbObjectWriter Writer;
+    Writer << "buildsystem_version"sv << Guid::FromString(kBuildSystemVersion);
+    Writer << "path"sv << "zentest-appstub"sv;
+
+    Writer.BeginArray("executables"sv);
+    Writer.BeginObject();
+    Writer << "name"sv << "zentest-appstub"sv;
+    Writer.AddAttachment("hash"sv, StubAttachment);
+    Writer << "size"sv << StubRawSize;
+    Writer.EndObject();
+    Writer.EndArray();
+
+    Writer.BeginArray("functions"sv);
+    Writer.BeginObject();
+    Writer << "name"sv << "Rot13"sv;
+    Writer << "version"sv << Guid::FromString(kRot13Version);
+    Writer.EndObject();
+    Writer.BeginObject();
+    Writer << "name"sv << "Sleep"sv;
+    Writer << "version"sv << Guid::FromString(kSleepVersion);
+    Writer.EndObject();
+    Writer.EndArray();
+
+    CbPackage WorkerPackage;
+    WorkerPackage.SetObject(Writer.Save());
+    WorkerPackage.AddAttachment(StubAttachment);
+
+    return WorkerPackage;
+}
+
+// Build a Rot13 action CbObject (not CbPackage) and populate the chunk resolver with the input attachment.
+static CbObject
+BuildRot13ActionForSession(std::string_view Input, InMemoryChunkResolver& Resolver)
+{
+    CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()),
+                                                             OodleCompressor::Selkie,
+                                                             OodleCompressionLevel::HyperFast4);
+
+    const IoHash   RawHash = Compressed.DecodeRawHash();
+    const uint64_t RawSize = Input.size();
+
+    // Session-based execution fetches inputs through the resolver rather than
+    // a package attachment, so register the compressed bytes here.
+    Resolver.AddChunk(RawHash, Compressed.GetCompressed().Flatten().AsIoBuffer());
+
+    CbAttachment Attachment(std::move(Compressed), RawHash);
+
+    CbObjectWriter Writer;
+    Writer << "Function"sv << "Rot13"sv;
+    Writer << "FunctionVersion"sv << Guid::FromString(kRot13Version);
+    Writer << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion);
+    Writer.BeginObject("Inputs"sv);
+    Writer.BeginObject("Source"sv);
+    Writer.AddAttachment("RawHash"sv, Attachment);
+    Writer << "RawSize"sv << RawSize;
+    Writer.EndObject();
+    Writer.EndObject();
+
+    return Writer.Save();
+}
+
+TEST_SUITE_BEGIN("server.function");
+
+// End-to-end smoke test: spawn a compute server, register the app-stub
+// worker, run Rot13 over a small input via the legacy /jobs endpoints, and
+// verify the transformed output.
+TEST_CASE("function.rot13")
+{
+    ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+    Instance.SetDataDir(TestEnv.CreateNewTestDir());
+    const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+    REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+    const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+    HttpClient Client(ComputeBaseUri);
+
+    const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+    // Submit action via legacy /jobs/{worker} endpoint
+    const std::string JobUrl = fmt::format("/jobs/{}", WorkerId.ToHexString());
+    HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
+    REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+
+    const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+    REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission");
+
+    // Poll for result via legacy /jobs/{lsn} endpoint
+    const std::string ResultUrl = fmt::format("/jobs/{}", Lsn);
+    HttpClient::Response ResultResp = PollForResult(Client, ResultUrl);
+    REQUIRE_MESSAGE(
+        ResultResp.StatusCode == HttpResponseCode::OK,
+        fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput()));
+
+    // Verify result: Rot13("Hello World") == "Uryyb Jbeyq"
+    CbPackage ResultPackage = ResultResp.AsPackage();
+    REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Action failed (empty result package)\nServer log:\n{}", Instance.GetLogOutput()));
+
+    CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
+}
+
+// Exercises the /workers listing and descriptor endpoints: empty list before
+// registration, worker visible afterwards, descriptor fields round-trip, and
+// unknown workers yield 404.
+TEST_CASE("function.workers")
+{
+    ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+    Instance.SetDataDir(TestEnv.CreateNewTestDir());
+    const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+    REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+    const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+    HttpClient Client(ComputeBaseUri);
+
+    // Before registration, GET /workers should return an empty list
+    HttpClient::Response EmptyListResp = Client.Get("/workers"sv);
+    REQUIRE_MESSAGE(EmptyListResp, "Failed to list workers before registration");
+    CHECK_EQ(EmptyListResp.AsObject()["workers"sv].AsArrayView().Num(), 0);
+
+    const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+    // GET /workers — the registered worker should appear in the listing
+    HttpClient::Response ListResp = Client.Get("/workers"sv);
+    REQUIRE_MESSAGE(ListResp, "Failed to list workers after registration");
+
+    bool WorkerFound = false;
+    for (auto& Item : ListResp.AsObject()["workers"sv])
+    {
+        if (Item.AsHash() == WorkerId)
+        {
+            WorkerFound = true;
+            break;
+        }
+    }
+
+    REQUIRE_MESSAGE(WorkerFound, fmt::format("Worker {} not found in worker listing", WorkerId.ToHexString()));
+
+    // GET /workers/{worker} — descriptor should match what was registered
+    const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString());
+    HttpClient::Response DescResp = Client.Get(WorkerUrl);
+    REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", int(DescResp.StatusCode)));
+
+    CbObject Desc = DescResp.AsObject();
+    CHECK_EQ(Desc["buildsystem_version"sv].AsUuid(), Guid::FromString(kBuildSystemVersion));
+    CHECK_EQ(Desc["path"sv].AsString(), "zentest-appstub"sv);
+
+    // Both functions registered by RegisterWorker() must be present with the
+    // expected version GUIDs.
+    bool Rot13Found = false;
+    bool SleepFound = false;
+    for (auto& Item : Desc["functions"sv])
+    {
+        std::string_view Name = Item.AsObjectView()["name"sv].AsString();
+        if (Name == "Rot13"sv)
+        {
+            CHECK_EQ(Item.AsObjectView()["version"sv].AsUuid(), Guid::FromString(kRot13Version));
+            Rot13Found = true;
+        }
+        else if (Name == "Sleep"sv)
+        {
+            CHECK_EQ(Item.AsObjectView()["version"sv].AsUuid(), Guid::FromString(kSleepVersion));
+            SleepFound = true;
+        }
+    }
+
+    CHECK_MESSAGE(Rot13Found, "Rot13 function not found in worker descriptor");
+    CHECK_MESSAGE(SleepFound, "Sleep function not found in worker descriptor");
+
+    // GET /workers/{unknown} — should return 404
+    const std::string UnknownUrl = fmt::format("/workers/{}", IoHash::Zero.ToHexString());
+    HttpClient::Response NotFoundResp = Client.Get(UnknownUrl);
+    CHECK_EQ(NotFoundResp.StatusCode, HttpResponseCode::NotFound);
+}
+
+// Full queue lifecycle: create a queue, verify it is listed, submit a Rot13
+// job through the queue-scoped endpoints, wait for completion, fetch and
+// verify the result, and check the final queue status counters.
+TEST_CASE("function.queues.lifecycle")
+{
+    ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+    Instance.SetDataDir(TestEnv.CreateNewTestDir());
+    const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+    REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+    const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+    HttpClient Client(ComputeBaseUri);
+
+    const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+    // Create a queue
+    HttpClient::Response CreateResp = Client.Post("/queues"sv);
+    REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText()));
+
+    const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+    REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation");
+
+    // Verify the queue appears in the listing
+    HttpClient::Response ListResp = Client.Get("/queues"sv);
+    REQUIRE_MESSAGE(ListResp, "Failed to list queues");
+
+    bool QueueFound = false;
+    for (auto& Item : ListResp.AsObject()["queues"sv])
+    {
+        if (Item.AsObjectView()["queue_id"sv].AsInt32() == QueueId)
+        {
+            QueueFound = true;
+            break;
+        }
+    }
+
+    REQUIRE_MESSAGE(QueueFound, fmt::format("Queue {} not found in queue listing", QueueId));
+
+    // Submit action via queue-scoped endpoint
+    const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+    HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
+    REQUIRE_MESSAGE(SubmitResp,
+                    fmt::format("Queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+
+    const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+    REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from queue job submission");
+
+    // Poll for completion via queue-scoped /completed endpoint
+    const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+    REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+                    fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}",
+                                Lsn,
+                                QueueId,
+                                Instance.GetLogOutput()));
+
+    // Retrieve result via queue-scoped /jobs/{lsn} endpoint
+    const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn);
+    HttpClient::Response ResultResp = Client.Get(ResultUrl);
+    REQUIRE_MESSAGE(
+        ResultResp.StatusCode == HttpResponseCode::OK,
+        fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput()));
+
+    // Verify result: Rot13("Hello World") == "Uryyb Jbeyq"
+    CbPackage ResultPackage = ResultResp.AsPackage();
+    REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput()));
+
+    CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
+
+    // Verify queue status reflects completion
+    const std::string StatusUrl = fmt::format("/queues/{}", QueueId);
+    HttpClient::Response StatusResp = Client.Get(StatusUrl);
+    REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+    CbObject QueueStatus = StatusResp.AsObject();
+    CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 1);
+    CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0);
+    CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 0);
+    CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "active");
+}
+
+// Cancelling a queue with a pending job: DELETE /queues/{id} should return
+// 204 and flip the queue state to "cancelled".
+TEST_CASE("function.queues.cancel")
+{
+    ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+    Instance.SetDataDir(TestEnv.CreateNewTestDir());
+    const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+    REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+    const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+    HttpClient Client(ComputeBaseUri);
+
+    const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+    // Create a queue
+    HttpClient::Response CreateResp = Client.Post("/queues"sv);
+    REQUIRE_MESSAGE(CreateResp, "Queue creation failed");
+
+    const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+    REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation");
+
+    // Submit a job
+    const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+    HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
+    REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+
+    // Cancel the queue
+    const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
+    HttpClient::Response CancelResp = Client.Delete(QueueUrl);
+    REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent,
+                    fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText()));
+
+    // Verify queue status shows cancelled
+    HttpClient::Response StatusResp = Client.Get(QueueUrl);
+    REQUIRE_MESSAGE(StatusResp, "Failed to get queue status after cancel");
+
+    CbObject QueueStatus = StatusResp.AsObject();
+    CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "cancelled");
+}
+
+// Remote queues are addressed by an opaque token rather than an integer id;
+// verify the token works for job submission, completion polling, and result
+// retrieval.
+TEST_CASE("function.queues.remote")
+{
+    ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+    Instance.SetDataDir(TestEnv.CreateNewTestDir());
+    const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+    REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+    const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+    HttpClient Client(ComputeBaseUri);
+
+    const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+    // Create a remote queue — response includes both an integer queue_id and an OID queue_token
+    HttpClient::Response CreateResp = Client.Post("/queues/remote"sv);
+    REQUIRE_MESSAGE(CreateResp,
+                    fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText()));
+
+    CbObject CreateObj = CreateResp.AsObject();
+    const std::string QueueToken = std::string(CreateObj["queue_token"sv].AsString());
+    REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation");
+
+    // All subsequent requests use the opaque token in place of the integer queue id
+    const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString());
+    HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
+    REQUIRE_MESSAGE(SubmitResp,
+                    fmt::format("Remote queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+
+    const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+    REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from remote queue job submission");
+
+    // Poll for completion via the token-addressed /completed endpoint
+    const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken);
+    REQUIRE_MESSAGE(
+        PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+        fmt::format("LSN {} did not appear in remote queue completed list within timeout\nServer log:\n{}", Lsn, Instance.GetLogOutput()));
+
+    // Retrieve result via the token-addressed /jobs/{lsn} endpoint
+    const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, Lsn);
+    HttpClient::Response ResultResp = Client.Get(ResultUrl);
+    REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK,
+                    fmt::format("Failed to retrieve result from remote queue: status={}\nServer log:\n{}",
+                                int(ResultResp.StatusCode),
+                                Instance.GetLogOutput()));
+
+    // Verify result: Rot13("Hello World") == "Uryyb Jbeyq"
+    CbPackage ResultPackage = ResultResp.AsPackage();
+    REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput()));
+
+    CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
+}
+
+// Cancelling a queue must interrupt an actively-running job: submit a long
+// Sleep action, cancel mid-run, and verify the job surfaces as cancelled
+// (not completed) in the queue counters.
+TEST_CASE("function.queues.cancel_running")
+{
+    ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+    Instance.SetDataDir(TestEnv.CreateNewTestDir());
+    const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+    REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+    const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+    HttpClient Client(ComputeBaseUri);
+
+    const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+    // Create a queue
+    HttpClient::Response CreateResp = Client.Post("/queues"sv);
+    REQUIRE_MESSAGE(CreateResp, "Queue creation failed");
+
+    const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+    REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation");
+
+    // Submit a Sleep job long enough that it will still be running when we cancel
+    const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+    HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000));
+    REQUIRE_MESSAGE(SubmitResp,
+                    fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+
+    const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+    REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission");
+
+    // Wait for the worker process to start executing before cancelling
+    Sleep(1'000);
+
+    // Cancel the queue, which should interrupt the running Sleep job
+    const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
+    HttpClient::Response CancelResp = Client.Delete(QueueUrl);
+    REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent,
+                    fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText()));
+
+    // The cancelled job should appear in the /completed endpoint once the process exits
+    const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+    REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+                    fmt::format("LSN {} did not appear in queue {} completed list after cancel\nServer log:\n{}",
+                                Lsn,
+                                QueueId,
+                                Instance.GetLogOutput()));
+
+    // Verify the queue reflects one cancelled action
+    HttpClient::Response StatusResp = Client.Get(QueueUrl);
+    REQUIRE_MESSAGE(StatusResp, "Failed to get queue status after cancel");
+
+    CbObject QueueStatus = StatusResp.AsObject();
+    CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "cancelled");
+    CHECK_EQ(QueueStatus["cancelled_count"sv].AsInt32(), 1);
+    CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+}
+
+// Same as function.queues.cancel_running, but addressing the queue through
+// its remote OID token for every operation (submit, cancel, poll, status).
+TEST_CASE("function.queues.remote_cancel")
+{
+    ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+    Instance.SetDataDir(TestEnv.CreateNewTestDir());
+    const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+    REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+    const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+    HttpClient Client(ComputeBaseUri);
+
+    const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+    // Create a remote queue to obtain an OID token for token-addressed cancellation
+    HttpClient::Response CreateResp = Client.Post("/queues/remote"sv);
+    REQUIRE_MESSAGE(CreateResp,
+                    fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText()));
+
+    const std::string QueueToken = std::string(CreateResp.AsObject()["queue_token"sv].AsString());
+    REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation");
+
+    // Submit a long-running Sleep job via the token-addressed endpoint
+    const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString());
+    HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000));
+    REQUIRE_MESSAGE(SubmitResp,
+                    fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+
+    const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+    REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission");
+
+    // Wait for the worker process to start executing before cancelling
+    Sleep(1'000);
+
+    // Cancel the queue via its OID token
+    const std::string QueueUrl = fmt::format("/queues/{}", QueueToken);
+    HttpClient::Response CancelResp = Client.Delete(QueueUrl);
+    REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent,
+                    fmt::format("Remote queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText()));
+
+    // The cancelled job should appear in the token-addressed /completed endpoint
+    const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken);
+    REQUIRE_MESSAGE(
+        PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+        fmt::format("LSN {} did not appear in remote queue completed list after cancel\nServer log:\n{}", Lsn, Instance.GetLogOutput()));
+
+    // Verify the queue status reflects the cancellation
+    HttpClient::Response StatusResp = Client.Get(QueueUrl);
+    REQUIRE_MESSAGE(StatusResp, "Failed to get remote queue status after cancel");
+
+    CbObject QueueStatus = StatusResp.AsObject();
+    CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "cancelled");
+    CHECK_EQ(QueueStatus["cancelled_count"sv].AsInt32(), 1);
+    CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+}
+
+// Drain semantics: after POST /queues/{id}/drain the queue must reject new
+// submissions with 424 (FailedDependency) while still allowing already-queued
+// work to run to completion; final status is "draining" with is_complete=true.
+TEST_CASE("function.queues.drain")
+{
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+	const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+	HttpClient		  Client(ComputeBaseUri);
+
+	const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+	// Create a queue
+	HttpClient::Response CreateResp = Client.Post("/queues"sv);
+	REQUIRE_MESSAGE(CreateResp, "Queue creation failed");
+
+	const int		  QueueId  = CreateResp.AsObject()["queue_id"sv].AsInt32();
+	const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
+
+	// Submit a long-running job so we can verify it completes even after drain
+	const std::string	 JobUrl	 = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+	HttpClient::Response Submit1 = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 2'000));
+	REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", int(Submit1.StatusCode)));
+	const int Lsn1 = Submit1.AsObject()["lsn"sv].AsInt32();
+
+	// Drain the queue
+	const std::string	 DrainUrl  = fmt::format("/queues/{}/drain", QueueId);
+	HttpClient::Response DrainResp = Client.Post(DrainUrl);
+	REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", int(DrainResp.StatusCode), DrainResp.ToText()));
+	CHECK_EQ(std::string(DrainResp.AsObject()["state"sv].AsString()), "draining");
+
+	// Second submission should be rejected with 424
+	HttpClient::Response Submit2 = Client.Post(JobUrl, BuildRot13ActionPackage("Hello"sv));
+	CHECK_EQ(Submit2.StatusCode, HttpResponseCode::FailedDependency);
+	CHECK_EQ(std::string(Submit2.AsObject()["error"sv].AsString()), "queue is draining");
+
+	// First job should still complete
+	const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+	REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn1),
+					fmt::format("LSN {} did not complete after drain\nServer log:\n{}", Lsn1, Instance.GetLogOutput()));
+
+	// Queue status should show draining + complete
+	HttpClient::Response StatusResp = Client.Get(QueueUrl);
+	REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+	CbObject QueueStatus = StatusResp.AsObject();
+	CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "draining");
+	CHECK(QueueStatus["is_complete"sv].AsBool());
+}
+
+// Priority scheduling: with --max-actions=1 a blocker job occupies the single
+// execution slot; three low-priority (0) jobs and one high-priority (10) job
+// are queued behind it. Once everything has completed, per-job time_Completed
+// timestamps from the queue-scoped history endpoint must show the high-priority
+// job finishing strictly before every low-priority job.
+TEST_CASE("function.priority")
+{
+	// Spawn server with max-actions=1 to guarantee serialized action execution,
+	// which lets us deterministically verify that higher-priority pending jobs
+	// are scheduled before lower-priority ones.
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--max-actions=1");
+	REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+	const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+	HttpClient		  Client(ComputeBaseUri);
+
+	const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+	// Create a queue for all test jobs
+	HttpClient::Response CreateResp = Client.Post("/queues"sv);
+	REQUIRE_MESSAGE(CreateResp, "Queue creation failed");
+
+	const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+	REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id");
+
+	// Submit a blocker Sleep job to occupy the single execution slot.
+	// Once the blocker is running, the scheduler must choose among the pending
+	// jobs by priority when the slot becomes free.
+	// NOTE(review): if all four submissions below don't land within the
+	// blocker's 1s sleep, the ordering becomes nondeterministic — confirm the
+	// blocker duration leaves enough margin on slow CI hosts.
+	const std::string	 BlockerJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString());
+	HttpClient::Response BlockerResp   = Client.Post(BlockerJobUrl, BuildSleepActionPackage("data"sv, 1'000));
+	REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", int(BlockerResp.StatusCode)));
+
+	// Submit 3 low-priority Rot13 jobs
+	const std::string LowJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString());
+
+	HttpClient::Response LowResp1 = Client.Post(LowJobUrl, BuildRot13ActionPackage("low1"sv));
+	REQUIRE_MESSAGE(LowResp1, "Low-priority job 1 submission failed");
+	const int LsnLow1 = LowResp1.AsObject()["lsn"sv].AsInt32();
+
+	HttpClient::Response LowResp2 = Client.Post(LowJobUrl, BuildRot13ActionPackage("low2"sv));
+	REQUIRE_MESSAGE(LowResp2, "Low-priority job 2 submission failed");
+	const int LsnLow2 = LowResp2.AsObject()["lsn"sv].AsInt32();
+
+	HttpClient::Response LowResp3 = Client.Post(LowJobUrl, BuildRot13ActionPackage("low3"sv));
+	REQUIRE_MESSAGE(LowResp3, "Low-priority job 3 submission failed");
+	const int LsnLow3 = LowResp3.AsObject()["lsn"sv].AsInt32();
+
+	// Submit 1 high-priority Rot13 job — should execute before the low-priority ones
+	const std::string	 HighJobUrl = fmt::format("/queues/{}/jobs/{}?priority=10", QueueId, WorkerId.ToHexString());
+	HttpClient::Response HighResp	= Client.Post(HighJobUrl, BuildRot13ActionPackage("high"sv));
+	REQUIRE_MESSAGE(HighResp, "High-priority job submission failed");
+	const int LsnHigh = HighResp.AsObject()["lsn"sv].AsInt32();
+
+	// Wait for all 4 priority-test jobs to appear in the queue's completed list.
+	// This avoids any snapshot-timing race: by the time we compare timestamps, all
+	// jobs have already finished and their history entries are stable.
+	const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+
+	{
+		bool	  AllCompleted = false;
+		Stopwatch WaitTimer;
+
+		// Poll (100 ms cadence, 30 s budget) until every LSN is reported completed.
+		while (!AllCompleted && WaitTimer.GetElapsedTimeMs() < 30'000)
+		{
+			HttpClient::Response Resp = Client.Get(CompletedUrl);
+
+			if (Resp)
+			{
+				bool FoundHigh = false;
+				bool FoundLow1 = false;
+				bool FoundLow2 = false;
+				bool FoundLow3 = false;
+
+				CbObject RespObj = Resp.AsObject();
+
+				for (auto& Item : RespObj["completed"sv])
+				{
+					const int Lsn = Item.AsInt32();
+					if (Lsn == LsnHigh)
+					{
+						FoundHigh = true;
+					}
+					else if (Lsn == LsnLow1)
+					{
+						FoundLow1 = true;
+					}
+					else if (Lsn == LsnLow2)
+					{
+						FoundLow2 = true;
+					}
+					else if (Lsn == LsnLow3)
+					{
+						FoundLow3 = true;
+					}
+				}
+
+				AllCompleted = FoundHigh && FoundLow1 && FoundLow2 && FoundLow3;
+			}
+
+			if (!AllCompleted)
+			{
+				Sleep(100);
+			}
+		}
+
+		REQUIRE_MESSAGE(
+			AllCompleted,
+			fmt::format(
+				"Not all priority test jobs completed within timeout (lsnHigh={} lsnLow1={} lsnLow2={} lsnLow3={})\nServer log:\n{}",
+				LsnHigh,
+				LsnLow1,
+				LsnLow2,
+				LsnLow3,
+				Instance.GetLogOutput()));
+	}
+
+	// Query the queue-scoped history to obtain the time_Completed timestamp for each
+	// job. The history endpoint records when each RunnerAction::State transition
+	// occurred, so time_Completed is the wall-clock tick at which the action finished.
+	// Using the queue-scoped endpoint avoids exposing history from other queues.
+	const std::string	 HistoryUrl	 = fmt::format("/queues/{}/history", QueueId);
+	HttpClient::Response HistoryResp = Client.Get(HistoryUrl);
+	REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history");
+
+	CbObject HistoryObj = HistoryResp.AsObject();
+
+	// Returns the time_Completed tick for the given LSN, or 0 if not present.
+	auto GetCompletedTimestamp = [&](int Lsn) -> uint64_t {
+		for (auto& Item : HistoryObj["history"sv])
+		{
+			if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn)
+			{
+				return Item.AsObjectView()["time_Completed"sv].AsUInt64();
+			}
+		}
+		return 0;
+	};
+
+	const uint64_t TimeHigh = GetCompletedTimestamp(LsnHigh);
+	const uint64_t TimeLow1 = GetCompletedTimestamp(LsnLow1);
+	const uint64_t TimeLow2 = GetCompletedTimestamp(LsnLow2);
+	const uint64_t TimeLow3 = GetCompletedTimestamp(LsnLow3);
+
+	REQUIRE_MESSAGE(TimeHigh != 0, fmt::format("lsnHigh={} not found in action history", LsnHigh));
+	REQUIRE_MESSAGE(TimeLow1 != 0, fmt::format("lsnLow1={} not found in action history", LsnLow1));
+	REQUIRE_MESSAGE(TimeLow2 != 0, fmt::format("lsnLow2={} not found in action history", LsnLow2));
+	REQUIRE_MESSAGE(TimeLow3 != 0, fmt::format("lsnLow3={} not found in action history", LsnLow3));
+
+	// The high-priority job must have completed strictly before every low-priority job
+	CHECK_MESSAGE(TimeHigh < TimeLow1,
+				  fmt::format("Priority ordering violated: lsnHigh={} completed at t={} but lsnLow1={} completed at t={} (expected later)",
+							  LsnHigh,
+							  TimeHigh,
+							  LsnLow1,
+							  TimeLow1));
+	CHECK_MESSAGE(TimeHigh < TimeLow2,
+				  fmt::format("Priority ordering violated: lsnHigh={} completed at t={} but lsnLow2={} completed at t={} (expected later)",
+							  LsnHigh,
+							  TimeHigh,
+							  LsnLow2,
+							  TimeLow2));
+	CHECK_MESSAGE(TimeHigh < TimeLow3,
+				  fmt::format("Priority ordering violated: lsnHigh={} completed at t={} but lsnLow3={} completed at t={} (expected later)",
+							  LsnHigh,
+							  TimeHigh,
+							  LsnLow3,
+							  TimeLow3));
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Remote worker synchronization tests
+//
+// These tests exercise the orchestrator discovery path where new compute
+// nodes appear over time and must receive previously registered workers
+// via SyncWorkersToRunner().
+
+// Worker sync on discovery: a worker registered on the session BEFORE any
+// runner exists must be pushed to a compute node once the mock orchestrator
+// starts advertising it; a Rot13 action then executes end-to-end on that node.
+TEST_CASE("function.remote.worker_sync_on_discovery")
+{
+	// Spawn real zenserver in compute mode
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t ServerPort = Instance.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(ServerPort != 0, Instance.GetLogOutput());
+
+	const std::string ServerUri = fmt::format("http://localhost:{}", ServerPort);
+
+	// Start mock orchestrator with empty worker list
+	MockOrchestratorFixture MockOrch;
+
+	// Create session infrastructure
+	InMemoryChunkResolver				Resolver;
+	ScopedTemporaryDirectory			SessionBaseDir;
+	zen::compute::ComputeServiceSession Session(Resolver);
+	Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint());
+	Session.SetOrchestratorBasePath(SessionBaseDir.Path());
+	Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+
+	// Register worker on session (stored locally, no runners yet)
+	CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+	Session.RegisterWorker(WorkerPackage);
+
+	// Update mock orchestrator to advertise the real server
+	MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri}}));
+
+	// Wait for scheduler to discover the runner (~5s throttle + margin)
+	Sleep(7'000);
+
+	// Submit Rot13 action via session
+	CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver);
+
+	zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueAction(ActionObj, 0);
+	REQUIRE_MESSAGE(EnqueueRes, "Action enqueue failed");
+
+	// Poll for result (200 ms cadence, 30 s budget)
+	CbPackage		 ResultPackage;
+	HttpResponseCode ResultCode = HttpResponseCode::Accepted;
+	Stopwatch		 Timer;
+
+	while (Timer.GetElapsedTimeMs() < 30'000)
+	{
+		ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage);
+		if (ResultCode == HttpResponseCode::OK)
+		{
+			break;
+		}
+		Sleep(200);
+	}
+
+	REQUIRE_MESSAGE(
+		ResultCode == HttpResponseCode::OK,
+		fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput()));
+
+	REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput()));
+
+	// Rot13("Hello World") = "Uryyb Jbeyq"
+	CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
+
+	Session.Shutdown();
+}
+
+// Late runner discovery: after a baseline action succeeds on server W1, a
+// second server W2 joins the orchestrator's advertised list. The session must
+// sync its already-registered worker to W2 (verified via W2's /compute/workers
+// endpoint) and subsequent actions must still complete on either node.
+TEST_CASE("function.remote.late_runner_discovery")
+{
+	// Spawn first server
+	ZenServerInstance Instance1(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance1.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port1 = Instance1.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(Port1 != 0, Instance1.GetLogOutput());
+
+	const std::string ServerUri1 = fmt::format("http://localhost:{}", Port1);
+
+	// Start mock orchestrator advertising W1
+	MockOrchestratorFixture MockOrch;
+	MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri1}}));
+
+	// Create session and register worker
+	InMemoryChunkResolver				Resolver;
+	ScopedTemporaryDirectory			SessionBaseDir;
+	zen::compute::ComputeServiceSession Session(Resolver);
+	Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint());
+	Session.SetOrchestratorBasePath(SessionBaseDir.Path());
+	Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+
+	CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+	Session.RegisterWorker(WorkerPackage);
+
+	// Wait for W1 discovery
+	Sleep(7'000);
+
+	// Baseline: submit Rot13 action and verify it completes on W1
+	{
+		CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver);
+
+		zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueAction(ActionObj, 0);
+		REQUIRE_MESSAGE(EnqueueRes, "Baseline action enqueue failed");
+
+		CbPackage		 ResultPackage;
+		HttpResponseCode ResultCode = HttpResponseCode::Accepted;
+		Stopwatch		 Timer;
+
+		while (Timer.GetElapsedTimeMs() < 30'000)
+		{
+			ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage);
+			if (ResultCode == HttpResponseCode::OK)
+			{
+				break;
+			}
+			Sleep(200);
+		}
+
+		REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK,
+						fmt::format("Baseline action did not complete in time\nServer log:\n{}", Instance1.GetLogOutput()));
+
+		// Rot13("Hello World") = "Uryyb Jbeyq"
+		CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
+	}
+
+	// Spawn second server
+	ZenServerInstance Instance2(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance2.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port2 = Instance2.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(Port2 != 0, Instance2.GetLogOutput());
+
+	const std::string ServerUri2 = fmt::format("http://localhost:{}", Port2);
+
+	// Update mock orchestrator to include both W1 and W2
+	MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri1}, {"worker-2", ServerUri2}}));
+
+	// Wait for W2 discovery
+	Sleep(7'000);
+
+	// Verify W2 received the worker by querying its /compute/workers endpoint directly
+	{
+		const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port2);
+		HttpClient		  Client(ComputeBaseUri);
+		HttpClient::Response ListResp = Client.Get("/workers"sv);
+		REQUIRE_MESSAGE(ListResp, "Failed to list workers on W2");
+
+		// The synced worker is identified by the hash of its package object.
+		bool WorkerFound = false;
+		for (auto& Item : ListResp.AsObject()["workers"sv])
+		{
+			if (Item.AsHash() == WorkerPackage.GetObjectHash())
+			{
+				WorkerFound = true;
+				break;
+			}
+		}
+
+		REQUIRE_MESSAGE(WorkerFound,
+						fmt::format("Worker not found on W2 after discovery — SyncWorkersToRunner may have failed\nW2 log:\n{}",
+									Instance2.GetLogOutput()));
+	}
+
+	// Submit another action and verify it completes (could run on either W1 or W2)
+	{
+		CbObject ActionObj = BuildRot13ActionForSession("Second Test"sv, Resolver);
+
+		zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueAction(ActionObj, 0);
+		REQUIRE_MESSAGE(EnqueueRes, "Second action enqueue failed");
+
+		CbPackage		 ResultPackage;
+		HttpResponseCode ResultCode = HttpResponseCode::Accepted;
+		Stopwatch		 Timer;
+
+		while (Timer.GetElapsedTimeMs() < 30'000)
+		{
+			ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage);
+			if (ResultCode == HttpResponseCode::OK)
+			{
+				break;
+			}
+			Sleep(200);
+		}
+
+		REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK,
+						fmt::format("Second action did not complete in time\nW1 log:\n{}\nW2 log:\n{}",
+									Instance1.GetLogOutput(),
+									Instance2.GetLogOutput()));
+
+		// Rot13("Second Test") = "Frpbaq Grfg"
+		CHECK_EQ(GetRot13Output(ResultPackage), "Frpbaq Grfg"sv);
+	}
+
+	Session.Shutdown();
+}
+
+// Queue association: enqueueing into a session-local queue must cause the
+// session to create a matching NON-implicit remote queue on the compute node,
+// observable via the node's /compute/queues listing after the action completes.
+TEST_CASE("function.remote.queue_association")
+{
+	// Spawn real zenserver as a remote compute node
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput());
+
+	// Start mock orchestrator advertising the server
+	MockOrchestratorFixture MockOrch;
+	MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}}));
+
+	// Create session infrastructure
+	InMemoryChunkResolver				Resolver;
+	ScopedTemporaryDirectory			SessionBaseDir;
+	zen::compute::ComputeServiceSession Session(Resolver);
+	Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint());
+	Session.SetOrchestratorBasePath(SessionBaseDir.Path());
+	Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+
+	// Register worker on session
+	CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+	Session.RegisterWorker(WorkerPackage);
+
+	// Wait for scheduler to discover the runner
+	Sleep(7'000);
+
+	// Create a local queue and submit action to it
+	auto QueueResult = Session.CreateQueue();
+	REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue");
+	const int QueueId = QueueResult.QueueId;
+
+	CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver);
+
+	zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0);
+	REQUIRE_MESSAGE(EnqueueRes, "Action enqueue to queue failed");
+
+	// Poll for result (200 ms cadence, 30 s budget)
+	CbPackage		 ResultPackage;
+	HttpResponseCode ResultCode = HttpResponseCode::Accepted;
+	Stopwatch		 Timer;
+
+	while (Timer.GetElapsedTimeMs() < 30'000)
+	{
+		ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage);
+		if (ResultCode == HttpResponseCode::OK)
+		{
+			break;
+		}
+		Sleep(200);
+	}
+
+	REQUIRE_MESSAGE(
+		ResultCode == HttpResponseCode::OK,
+		fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput()));
+
+	REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput()));
+	CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
+
+	// Verify that a non-implicit remote queue was created on the compute node
+	HttpClient Client(Instance.GetBaseUri() + "/compute");
+
+	HttpClient::Response QueuesResp = Client.Get("/queues"sv);
+	REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server");
+
+	bool RemoteQueueFound = false;
+	for (auto& Item : QueuesResp.AsObject()["queues"sv])
+	{
+		if (!Item.AsObjectView()["implicit"sv].AsBool())
+		{
+			RemoteQueueFound = true;
+			break;
+		}
+	}
+
+	CHECK_MESSAGE(RemoteQueueFound, "Expected a non-implicit remote queue on the compute node");
+
+	Session.Shutdown();
+}
+
+// Cancel propagation: cancelling a session-local queue while its Sleep action
+// runs on a remote compute node must mark the local queue Cancelled AND flip
+// the associated non-implicit remote queue on the node to state "cancelled".
+TEST_CASE("function.remote.queue_cancel_propagation")
+{
+	// Spawn real zenserver as a remote compute node
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput());
+
+	// Start mock orchestrator advertising the server
+	MockOrchestratorFixture MockOrch;
+	MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}}));
+
+	// Create session infrastructure
+	InMemoryChunkResolver				Resolver;
+	ScopedTemporaryDirectory			SessionBaseDir;
+	zen::compute::ComputeServiceSession Session(Resolver);
+	Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint());
+	Session.SetOrchestratorBasePath(SessionBaseDir.Path());
+	Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+
+	// Register worker on session
+	CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+	Session.RegisterWorker(WorkerPackage);
+
+	// Wait for scheduler to discover the runner
+	Sleep(7'000);
+
+	// Create a local queue and submit a long-running Sleep action
+	auto QueueResult = Session.CreateQueue();
+	REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue");
+	const int QueueId = QueueResult.QueueId;
+
+	CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver);
+
+	zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0);
+	REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed");
+
+	// Wait for the action to start running on the remote
+	Sleep(2'000);
+
+	// Cancel the local queue — this should propagate to the remote
+	Session.CancelQueue(QueueId);
+
+	// Poll for the action to complete (as cancelled)
+	// NOTE(review): unlike the sibling tests, ResultCode is not asserted after
+	// this loop; only the queue-state checks below gate the test — confirm
+	// that a timeout here is intentionally tolerated.
+	CbPackage		 ResultPackage;
+	HttpResponseCode ResultCode = HttpResponseCode::Accepted;
+	Stopwatch		 Timer;
+
+	while (Timer.GetElapsedTimeMs() < 30'000)
+	{
+		ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage);
+		if (ResultCode == HttpResponseCode::OK)
+		{
+			break;
+		}
+		Sleep(200);
+	}
+
+	// Verify the local queue shows cancelled
+	auto QueueStatus = Session.GetQueueStatus(QueueId);
+	CHECK(QueueStatus.State == zen::compute::ComputeServiceSession::QueueState::Cancelled);
+
+	// Verify the remote queue on the compute node is also cancelled
+	HttpClient Client(Instance.GetBaseUri() + "/compute");
+
+	HttpClient::Response QueuesResp = Client.Get("/queues"sv);
+	REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server");
+
+	bool RemoteQueueCancelled = false;
+	for (auto& Item : QueuesResp.AsObject()["queues"sv])
+	{
+		if (!Item.AsObjectView()["implicit"sv].AsBool())
+		{
+			RemoteQueueCancelled = std::string(Item.AsObjectView()["state"sv].AsString()) == "cancelled";
+			break;
+		}
+	}
+
+	CHECK_MESSAGE(RemoteQueueCancelled, "Expected the remote queue to be cancelled");
+
+	Session.Shutdown();
+}
+
+// HTTP abandon: POST /abandon while a Sleep job runs must flip /ready from
+// 200 to 503, surface the job in /completed as abandoned (abandoned_count=1,
+// completed_count=0, active_count=0), and reject any further submissions.
+TEST_CASE("function.abandon_running_http")
+{
+	// Spawn a real zenserver to execute a long-running action, then abandon via HTTP endpoint
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+	const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+	HttpClient		  Client(ComputeBaseUri);
+
+	const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+	// Create a queue and submit a long-running Sleep job
+	HttpClient::Response CreateResp = Client.Post("/queues"sv);
+	REQUIRE_MESSAGE(CreateResp, "Queue creation failed");
+
+	const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+	REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id");
+
+	const std::string	 JobUrl		= fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+	HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000));
+	REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", int(SubmitResp.StatusCode)));
+
+	const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+	REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN");
+
+	// Wait for the process to start running
+	Sleep(1'000);
+
+	// Verify the ready endpoint returns OK before abandon
+	{
+		HttpClient::Response ReadyResp = Client.Get("/ready"sv);
+		CHECK(ReadyResp.StatusCode == HttpResponseCode::OK);
+	}
+
+	// Trigger abandon via the HTTP endpoint
+	HttpClient::Response AbandonResp = Client.Post("/abandon"sv);
+	REQUIRE_MESSAGE(AbandonResp.StatusCode == HttpResponseCode::OK,
+					fmt::format("Abandon request failed: status={}, body={}", int(AbandonResp.StatusCode), AbandonResp.ToText()));
+
+	// Ready endpoint should now return 503
+	{
+		HttpClient::Response ReadyResp = Client.Get("/ready"sv);
+		CHECK(ReadyResp.StatusCode == HttpResponseCode::ServiceUnavailable);
+	}
+
+	// The abandoned action should appear in the completed endpoint once the process exits
+	const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+	REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+					fmt::format("LSN {} did not appear in queue {} completed list after abandon\nServer log:\n{}",
+								Lsn,
+								QueueId,
+								Instance.GetLogOutput()));
+
+	// Verify the queue reflects one abandoned action
+	const std::string	 QueueUrl	= fmt::format("/queues/{}", QueueId);
+	HttpClient::Response StatusResp = Client.Get(QueueUrl);
+	REQUIRE_MESSAGE(StatusResp, "Failed to get queue status after abandon");
+
+	CbObject QueueStatus = StatusResp.AsObject();
+	CHECK_EQ(QueueStatus["abandoned_count"sv].AsInt32(), 1);
+	CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+	CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0);
+
+	// Submitting new work should be rejected
+	HttpClient::Response RejectedResp = Client.Post(JobUrl, BuildRot13ActionPackage("rejected"sv));
+	CHECK_MESSAGE(RejectedResp.StatusCode != HttpResponseCode::OK, "Expected action submission to be rejected in Abandoned state");
+}
+
+// Session abandon with pending work: with no runners, three enqueued actions
+// stay pending; transitioning to Abandoned must resolve all of them as
+// abandoned results, reject new submissions, leave the session unhealthy, and
+// still permit the Abandoned -> Sunset transition.
+TEST_CASE("function.session.abandon_pending")
+{
+	// Create a session with no runners so actions stay pending
+	InMemoryChunkResolver				Resolver;
+	ScopedTemporaryDirectory			SessionBaseDir;
+	zen::compute::ComputeServiceSession Session(Resolver);
+	Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+
+	CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+	Session.RegisterWorker(WorkerPackage);
+
+	// Enqueue several actions — they will stay pending because there are no runners
+	auto QueueResult = Session.CreateQueue();
+	REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue");
+
+	CbObject ActionObj = BuildRot13ActionForSession("abandon-test"sv, Resolver);
+
+	auto Enqueue1 = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0);
+	auto Enqueue2 = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0);
+	auto Enqueue3 = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0);
+	REQUIRE_MESSAGE(Enqueue1, "Failed to enqueue action 1");
+	REQUIRE_MESSAGE(Enqueue2, "Failed to enqueue action 2");
+	REQUIRE_MESSAGE(Enqueue3, "Failed to enqueue action 3");
+
+	// Transition to Abandoned — should mark all pending actions as Abandoned
+	bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned);
+	CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned");
+	CHECK(Session.GetSessionState() == zen::compute::ComputeServiceSession::SessionState::Abandoned);
+	CHECK(!Session.IsHealthy());
+
+	// Give the scheduler thread time to process the state changes.
+	// NOTE(review): a fixed 2s sleep is a timing heuristic; a poll on the
+	// results map would be more robust on loaded CI hosts.
+	Sleep(2'000);
+
+	// All three actions should now be in the results map as abandoned
+	for (int Lsn : {Enqueue1.Lsn, Enqueue2.Lsn, Enqueue3.Lsn})
+	{
+		CbPackage		 Result;
+		HttpResponseCode Code = Session.GetActionResult(Lsn, Result);
+		CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, int(Code)));
+	}
+
+	// Queue should show 0 active, 3 abandoned
+	auto Status = Session.GetQueueStatus(QueueResult.QueueId);
+	CHECK_EQ(Status.ActiveCount, 0);
+	CHECK_EQ(Status.AbandonedCount, 3);
+
+	// New actions should be rejected
+	auto Rejected = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0);
+	CHECK_MESSAGE(!Rejected, "Expected action submission to be rejected in Abandoned state");
+
+	// Abandoned → Sunset should be valid
+	CHECK(Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Sunset));
+
+	Session.Shutdown();
+}
+
+// Session abandon with running work: while a Sleep action executes on a real
+// remote compute node, transitioning the session to Abandoned must resolve the
+// action as abandoned (queue shows 1 abandoned / 0 completed / 0 active) and
+// leave the session unhealthy.
+TEST_CASE("function.session.abandon_running")
+{
+	// Spawn a real zenserver as a remote compute node
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput());
+
+	// Start mock orchestrator advertising the server
+	MockOrchestratorFixture MockOrch;
+	MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}}));
+
+	// Create session infrastructure
+	InMemoryChunkResolver				Resolver;
+	ScopedTemporaryDirectory			SessionBaseDir;
+	zen::compute::ComputeServiceSession Session(Resolver);
+	Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint());
+	Session.SetOrchestratorBasePath(SessionBaseDir.Path());
+	Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+
+	CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+	Session.RegisterWorker(WorkerPackage);
+
+	// Wait for scheduler to discover the runner
+	Sleep(7'000);
+
+	// Create a queue and submit a long-running Sleep action
+	auto QueueResult = Session.CreateQueue();
+	REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue");
+	const int QueueId = QueueResult.QueueId;
+
+	CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver);
+
+	auto EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0);
+	REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed");
+
+	// Wait for the action to start running on the remote
+	Sleep(2'000);
+
+	// Transition to Abandoned — should abandon the running action
+	bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned);
+	CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned");
+	CHECK(!Session.IsHealthy());
+
+	// Poll for the action to complete (as abandoned)
+	CbPackage		 ResultPackage;
+	HttpResponseCode ResultCode = HttpResponseCode::Accepted;
+	Stopwatch		 Timer;
+
+	while (Timer.GetElapsedTimeMs() < 30'000)
+	{
+		ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage);
+		if (ResultCode == HttpResponseCode::OK)
+		{
+			break;
+		}
+		Sleep(200);
+	}
+
+	REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK,
+					fmt::format("Action did not complete within timeout\nServer log:\n{}", Instance.GetLogOutput()));
+
+	// Verify the queue shows abandoned, not completed
+	auto QueueStatus = Session.GetQueueStatus(QueueId);
+	CHECK_EQ(QueueStatus.ActiveCount, 0);
+	CHECK_EQ(QueueStatus.AbandonedCount, 1);
+	CHECK_EQ(QueueStatus.CompletedCount, 0);
+
+	Session.Shutdown();
+}
+
+// Abandon propagation scope: abandoning the SESSION must abandon its running
+// remote action (local queue shows 1 abandoned) and mark the session
+// unhealthy, while the remote compute node itself must keep reporting /ready
+// as 200 — only the parent session is abandoned, not the node.
+TEST_CASE("function.remote.abandon_propagation")
+{
+	// Spawn real zenserver as a remote compute node
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput());
+
+	// Start mock orchestrator advertising the server
+	MockOrchestratorFixture MockOrch;
+	MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}}));
+
+	// Create session infrastructure
+	InMemoryChunkResolver				Resolver;
+	ScopedTemporaryDirectory			SessionBaseDir;
+	zen::compute::ComputeServiceSession Session(Resolver);
+	Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint());
+	Session.SetOrchestratorBasePath(SessionBaseDir.Path());
+	Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+
+	// Register worker on session
+	CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+	Session.RegisterWorker(WorkerPackage);
+
+	// Wait for scheduler to discover the runner
+	Sleep(7'000);
+
+	// Create a local queue and submit a long-running Sleep action
+	auto QueueResult = Session.CreateQueue();
+	REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue");
+	const int QueueId = QueueResult.QueueId;
+
+	CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver);
+
+	auto EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0);
+	REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed");
+
+	// Wait for the action to start running on the remote
+	Sleep(2'000);
+
+	// Transition to Abandoned — should abandon the running action and propagate
+	bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned);
+	CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned");
+
+	// Poll for the action to complete
+	CbPackage		 ResultPackage;
+	HttpResponseCode ResultCode = HttpResponseCode::Accepted;
+	Stopwatch		 Timer;
+
+	while (Timer.GetElapsedTimeMs() < 30'000)
+	{
+		ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage);
+		if (ResultCode == HttpResponseCode::OK)
+		{
+			break;
+		}
+		Sleep(200);
+	}
+
+	REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK,
+					fmt::format("Action did not complete within timeout\nServer log:\n{}", Instance.GetLogOutput()));
+
+	// Verify the local queue shows abandoned
+	auto QueueStatus = Session.GetQueueStatus(QueueId);
+	CHECK_EQ(QueueStatus.ActiveCount, 0);
+	CHECK_EQ(QueueStatus.AbandonedCount, 1);
+
+	// Session should not be healthy
+	CHECK(!Session.IsHealthy());
+
+	// The remote compute node should still be healthy (only the parent abandoned)
+	HttpClient			 RemoteClient(Instance.GetBaseUri() + "/compute");
+	HttpClient::Response ReadyResp = RemoteClient.Get("/ready"sv);
+	CHECK_MESSAGE(ReadyResp.StatusCode == HttpResponseCode::OK, "Remote compute node should still be healthy");
+
+	Session.Shutdown();
+}
+
+TEST_SUITE_END();
+
+} // namespace zen::tests::compute
+
+#endif
diff --git a/src/zenserver-test/hub-tests.cpp b/src/zenserver-test/hub-tests.cpp
index 42a5dcae4..11531e30f 100644
--- a/src/zenserver-test/hub-tests.cpp
+++ b/src/zenserver-test/hub-tests.cpp
@@ -24,7 +24,7 @@ namespace zen::tests::hub {
using namespace std::literals;
-TEST_SUITE_BEGIN("hub.lifecycle");
+TEST_SUITE_BEGIN("server.hub");
TEST_CASE("hub.lifecycle.basic")
{
@@ -230,9 +230,7 @@ TEST_CASE("hub.lifecycle.children")
}
}
-TEST_SUITE_END();
-
-TEST_CASE("hub.consul.lifecycle")
+TEST_CASE("hub.consul.lifecycle" * doctest::skip())
{
zen::consul::ConsulProcess ConsulProc;
ConsulProc.SpawnConsulAgent();
@@ -248,5 +246,7 @@ TEST_CASE("hub.consul.lifecycle")
ConsulProc.StopConsulAgent();
}
+TEST_SUITE_END();
+
} // namespace zen::tests::hub
#endif
diff --git a/src/zenserver-test/logging-tests.cpp b/src/zenserver-test/logging-tests.cpp
new file mode 100644
index 000000000..2e530ff92
--- /dev/null
+++ b/src/zenserver-test/logging-tests.cpp
@@ -0,0 +1,261 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/zencore.h>
+
+#if ZEN_WITH_TESTS
+
+# include "zenserver-test.h"
+
+# include <zencore/filesystem.h>
+# include <zencore/logging.h>
+# include <zencore/testing.h>
+# include <zenutil/zenserverprocess.h>
+
+namespace zen::tests {
+
+using namespace std::literals;
+
+TEST_SUITE_BEGIN("server.logging");
+
+//////////////////////////////////////////////////////////////////////////
+
+static bool
+LogContains(const std::string& Log, std::string_view Needle)
+{
+ return Log.find(Needle) != std::string::npos;
+}
+
+static std::string
+ReadFileToString(const std::filesystem::path& Path)
+{
+ FileContents Contents = ReadFile(Path);
+ if (Contents.ErrorCode)
+ {
+ return {};
+ }
+
+ IoBuffer Content = Contents.Flatten();
+ if (!Content)
+ {
+ return {};
+ }
+
+ return std::string(static_cast<const char*>(Content.Data()), Content.Size());
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// Verify that a log file is created at the default location (DataDir/logs/zenserver.log)
+// even without --abslog. The file must contain "server session id" (logged at INFO
+// to all registered loggers during init) and "log starting at" (emitted once a file
+// sink is first opened).
+TEST_CASE("logging.file.default")
+{
+ const std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetDataDir(TestDir);
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+ CHECK_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ Instance.Shutdown();
+
+ const std::filesystem::path DefaultLogFile = TestDir / "logs" / "zenserver.log";
+ CHECK_MESSAGE(std::filesystem::exists(DefaultLogFile), "Default log file was not created");
+ const std::string FileLog = ReadFileToString(DefaultLogFile);
+ CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog);
+ CHECK_MESSAGE(LogContains(FileLog, "log starting at"), FileLog);
+}
+
+// --quiet sets the console sink level to WARN. The formatted "[info] ..."
+// entry written by the default logger's console sink must therefore not appear
+// in captured stdout. (The "console" named logger — used by ZEN_CONSOLE_*
+// macros — may still emit plain-text messages without a level marker, so we
+// check for the absence of the FullFormatter "[info]" prefix rather than the
+// message text itself.)
+TEST_CASE("logging.console.quiet")
+{
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--quiet");
+ CHECK_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ Instance.Shutdown();
+
+ const std::string Log = Instance.GetLogOutput();
+ CHECK_MESSAGE(!LogContains(Log, "[info] server session id"), Log);
+}
+
+// --noconsole removes the stdout sink entirely, so the captured console output
+// must not contain any log entries from the logging system.
+TEST_CASE("logging.console.disabled")
+{
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--noconsole");
+ CHECK_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ Instance.Shutdown();
+
+ const std::string Log = Instance.GetLogOutput();
+ CHECK_MESSAGE(!LogContains(Log, "server session id"), Log);
+}
+
+// --abslog <path> creates a rotating log file at the specified path.
+// The file must contain "server session id" (logged at INFO to all loggers
+// during init) and "log starting at" (emitted once a file sink is active).
+TEST_CASE("logging.file.basic")
+{
+ const std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ const std::filesystem::path LogFile = TestDir / "test.log";
+
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetDataDir(TestDir);
+
+ const std::string LogArg = fmt::format("--abslog {}", LogFile.string());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg);
+ CHECK_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ Instance.Shutdown();
+
+ CHECK_MESSAGE(std::filesystem::exists(LogFile), "Log file was not created");
+ const std::string FileLog = ReadFileToString(LogFile);
+ CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog);
+ CHECK_MESSAGE(LogContains(FileLog, "log starting at"), FileLog);
+}
+
+// --abslog with a .json extension selects the JSON formatter.
+// Each log entry must be a JSON object containing at least the "message"
+// and "source" fields.
+TEST_CASE("logging.file.json")
+{
+ const std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ const std::filesystem::path LogFile = TestDir / "test.json";
+
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetDataDir(TestDir);
+
+ const std::string LogArg = fmt::format("--abslog {}", LogFile.string());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg);
+ CHECK_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ Instance.Shutdown();
+
+ CHECK_MESSAGE(std::filesystem::exists(LogFile), "JSON log file was not created");
+ const std::string FileLog = ReadFileToString(LogFile);
+ CHECK_MESSAGE(LogContains(FileLog, "\"message\""), FileLog);
+ CHECK_MESSAGE(LogContains(FileLog, "\"source\": \"zenserver\""), FileLog);
+ CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog);
+}
+
+// --log-id <id> is automatically set to the server instance name in test mode.
+// The JSON formatter emits this value as the "id" field, so every entry in a
+// .json log file must carry a non-empty "id".
+TEST_CASE("logging.log_id")
+{
+ const std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ const std::filesystem::path LogFile = TestDir / "test.json";
+
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetDataDir(TestDir);
+
+ const std::string LogArg = fmt::format("--abslog {}", LogFile.string());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg);
+ CHECK_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ Instance.Shutdown();
+
+ CHECK_MESSAGE(std::filesystem::exists(LogFile), "JSON log file was not created");
+ const std::string FileLog = ReadFileToString(LogFile);
+ // The JSON formatter writes the log-id as: "id": "<value>",
+ CHECK_MESSAGE(LogContains(FileLog, "\"id\": \""), FileLog);
+}
+
+// --log-warn <logger> raises the level threshold above INFO so that INFO messages
+// are filtered. "server session id" is broadcast at INFO to all loggers: it must
+// appear in the main file sink (default logger unaffected) but must NOT appear in
+// http.log where the http_requests logger now has a WARN threshold.
+TEST_CASE("logging.level.warn_suppresses_info")
+{
+ const std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ const std::filesystem::path LogFile = TestDir / "test.log";
+
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetDataDir(TestDir);
+
+ const std::string LogArg = fmt::format("--abslog {} --log-warn http_requests", LogFile.string());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg);
+ CHECK_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ Instance.Shutdown();
+
+ CHECK_MESSAGE(std::filesystem::exists(LogFile), "Log file was not created");
+ const std::string FileLog = ReadFileToString(LogFile);
+ CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog);
+
+ const std::filesystem::path HttpLogFile = TestDir / "logs" / "http.log";
+ CHECK_MESSAGE(std::filesystem::exists(HttpLogFile), "http.log was not created");
+ const std::string HttpLog = ReadFileToString(HttpLogFile);
+ CHECK_MESSAGE(!LogContains(HttpLog, "server session id"), HttpLog);
+}
+
+// --log-info <logger> sets an explicit INFO threshold. The INFO "server session id"
+// broadcast must still land in http.log, confirming that INFO messages are not
+// filtered when the logger level is exactly INFO.
+TEST_CASE("logging.level.info_allows_info")
+{
+ const std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ const std::filesystem::path LogFile = TestDir / "test.log";
+
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetDataDir(TestDir);
+
+ const std::string LogArg = fmt::format("--abslog {} --log-info http_requests", LogFile.string());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg);
+ CHECK_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ Instance.Shutdown();
+
+ const std::filesystem::path HttpLogFile = TestDir / "logs" / "http.log";
+ CHECK_MESSAGE(std::filesystem::exists(HttpLogFile), "http.log was not created");
+ const std::string HttpLog = ReadFileToString(HttpLogFile);
+ CHECK_MESSAGE(LogContains(HttpLog, "server session id"), HttpLog);
+}
+
+// --log-off <logger> silences a named logger entirely.
+// "server session id" is broadcast at INFO to all registered loggers via
+// spdlog::apply_all during init. When the "http_requests" logger is set to
+// OFF its dedicated http.log file must not contain that message.
+// The main file sink (via --abslog) must be unaffected.
+TEST_CASE("logging.level.off_specific_logger")
+{
+ const std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ const std::filesystem::path LogFile = TestDir / "test.log";
+
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetDataDir(TestDir);
+
+ const std::string LogArg = fmt::format("--abslog {} --log-off http_requests", LogFile.string());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg);
+ CHECK_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ Instance.Shutdown();
+
+ // Main log file must still have the startup message
+ CHECK_MESSAGE(std::filesystem::exists(LogFile), "Log file was not created");
+ const std::string FileLog = ReadFileToString(LogFile);
+ CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog);
+
+ // http.log is created by the RotatingFileSink but the logger is OFF, so
+ // the broadcast "server session id" message must not have been written to it
+ const std::filesystem::path HttpLogFile = TestDir / "logs" / "http.log";
+ CHECK_MESSAGE(std::filesystem::exists(HttpLogFile), "http.log was not created");
+ const std::string HttpLog = ReadFileToString(HttpLogFile);
+ CHECK_MESSAGE(!LogContains(HttpLog, "server session id"), HttpLog);
+}
+
+TEST_SUITE_END();
+
+} // namespace zen::tests
+
+#endif
diff --git a/src/zenserver-test/nomad-tests.cpp b/src/zenserver-test/nomad-tests.cpp
new file mode 100644
index 000000000..f8f5a9a30
--- /dev/null
+++ b/src/zenserver-test/nomad-tests.cpp
@@ -0,0 +1,130 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#if ZEN_WITH_TESTS && ZEN_WITH_NOMAD
+# include "zenserver-test.h"
+# include <zencore/filesystem.h>
+# include <zencore/logging.h>
+# include <zencore/testing.h>
+# include <zencore/timer.h>
+# include <zenhttp/httpclient.h>
+# include <zennomad/nomadclient.h>
+# include <zennomad/nomadprocess.h>
+# include <zenutil/zenserverprocess.h>
+
+# include <fmt/format.h>
+
+namespace zen::tests::nomad_tests {
+
+using namespace std::literals;
+
+TEST_SUITE_BEGIN("server.nomad");
+
+TEST_CASE("nomad.client.lifecycle" * doctest::skip())
+{
+ zen::nomad::NomadProcess NomadProc;
+ NomadProc.SpawnNomadAgent();
+
+ zen::nomad::NomadTestClient Client("http://localhost:4646/");
+
+ // Submit a simple batch job that sleeps briefly
+# if ZEN_PLATFORM_WINDOWS
+ auto Job = Client.SubmitJob("zen-test-job", "cmd.exe", {"/C", "timeout /t 10 /nobreak"});
+# else
+ auto Job = Client.SubmitJob("zen-test-job", "/bin/sleep", {"10"});
+# endif
+ REQUIRE(!Job.Id.empty());
+ CHECK_EQ(Job.Status, "pending");
+
+ // Poll until the job is running (or dead)
+ {
+ Stopwatch Timer;
+ bool FoundRunning = false;
+ while (Timer.GetElapsedTimeMs() < 15000)
+ {
+ auto Status = Client.GetJobStatus("zen-test-job");
+ if (Status.Status == "running")
+ {
+ FoundRunning = true;
+ break;
+ }
+ if (Status.Status == "dead")
+ {
+ break;
+ }
+ Sleep(500);
+ }
+ CHECK(FoundRunning);
+ }
+
+ // Verify allocations exist
+ auto Allocs = Client.GetAllocations("zen-test-job");
+ CHECK(!Allocs.empty());
+
+ // Stop the job
+ Client.StopJob("zen-test-job");
+
+ // Verify it reaches dead state
+ {
+ Stopwatch Timer;
+ bool FoundDead = false;
+ while (Timer.GetElapsedTimeMs() < 10000)
+ {
+ auto Status = Client.GetJobStatus("zen-test-job");
+ if (Status.Status == "dead")
+ {
+ FoundDead = true;
+ break;
+ }
+ Sleep(500);
+ }
+ CHECK(FoundDead);
+ }
+
+ NomadProc.StopNomadAgent();
+}
+
+TEST_CASE("nomad.provisioner.integration" * doctest::skip())
+{
+ zen::nomad::NomadProcess NomadProc;
+ NomadProc.SpawnNomadAgent();
+
+ // Spawn zenserver in compute mode with Nomad provisioning enabled
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+
+ std::filesystem::path ZenServerPath = TestEnv.ProgramBaseDir() / "zenserver" ZEN_EXE_SUFFIX_LITERAL;
+
+ std::string NomadArgs = fmt::format(
+ "--nomad-enabled=true"
+ " --nomad-server=http://localhost:4646"
+ " --nomad-driver=raw_exec"
+ " --nomad-binary-path={}"
+ " --nomad-max-cores=32"
+ " --nomad-cores-per-job=32",
+ ZenServerPath.string());
+
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(NomadArgs);
+ REQUIRE(Port != 0);
+
+ // Give the provisioner time to submit jobs.
+ // The management thread has a 5s wait between cycles, and the HTTP client has
+ // a 10s connect timeout, so we need to allow enough time for at least one full cycle.
+ Sleep(15000);
+
+ // Verify jobs were submitted to Nomad
+ zen::nomad::NomadTestClient NomadClient("http://localhost:4646/");
+
+ auto Jobs = NomadClient.ListJobs("zenserver-worker");
+
+ ZEN_INFO("nomad.provisioner.integration: found {} jobs with prefix 'zenserver-worker'", Jobs.size());
+ CHECK_MESSAGE(!Jobs.empty(), Instance.GetLogOutput());
+
+ Instance.Shutdown();
+ NomadProc.StopNomadAgent();
+}
+
+TEST_SUITE_END();
+
+} // namespace zen::tests::nomad_tests
+#endif
diff --git a/src/zenserver-test/objectstore-tests.cpp b/src/zenserver-test/objectstore-tests.cpp
new file mode 100644
index 000000000..f3db5fdf6
--- /dev/null
+++ b/src/zenserver-test/objectstore-tests.cpp
@@ -0,0 +1,74 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#if ZEN_WITH_TESTS
+# include "zenserver-test.h"
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+# include <zenutil/zenserverprocess.h>
+# include <zenhttp/httpclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <tsl/robin_set.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::tests {
+
+using namespace std::literals;
+
+TEST_SUITE_BEGIN("server.objectstore");
+
+TEST_CASE("objectstore.blobs")
+{
+ std::string_view Bucket = "bkt"sv;
+
+ std::vector<IoHash> CompressedBlobsHashes;
+ std::vector<uint64_t> BlobsSizes;
+ std::vector<uint64_t> CompressedBlobsSizes;
+ {
+ ZenServerInstance Instance(TestEnv);
+
+ const uint16_t PortNumber = Instance.SpawnServerAndWaitUntilReady("--objectstore-enabled");
+ CHECK(PortNumber != 0);
+
+ HttpClient Client(Instance.GetBaseUri() + "/obj/");
+
+ for (size_t I = 0; I < 5; I++)
+ {
+ IoBuffer Blob = CreateSemiRandomBlob(4711 + I * 7);
+ BlobsSizes.push_back(Blob.GetSize());
+ CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(std::move(Blob)));
+ CompressedBlobsHashes.push_back(CompressedBlob.DecodeRawHash());
+ CompressedBlobsSizes.push_back(CompressedBlob.GetCompressedSize());
+ IoBuffer Payload = std::move(CompressedBlob).GetCompressed().Flatten().AsIoBuffer();
+ Payload.SetContentType(ZenContentType::kCompressedBinary);
+
+ std::string ObjectPath = fmt::format("{}/{}",
+ CompressedBlobsHashes.back().ToHexString().substr(0, 2),
+ CompressedBlobsHashes.back().ToHexString());
+
+ HttpClient::Response Result = Client.Put(fmt::format("bucket/{}/{}.utoc", Bucket, ObjectPath), Payload);
+ CHECK(Result);
+ }
+
+ for (size_t I = 0; I < 5; I++)
+ {
+ std::string ObjectPath =
+ fmt::format("{}/{}", CompressedBlobsHashes[I].ToHexString().substr(0, 2), CompressedBlobsHashes[I].ToHexString());
+ HttpClient::Response Result = Client.Get(fmt::format("bucket/{}/{}.utoc", Bucket, ObjectPath));
+ CHECK(Result);
+ CHECK_EQ(Result.ResponsePayload.GetSize(), CompressedBlobsSizes[I]);
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed =
+ CompressedBuffer::FromCompressed(SharedBuffer(std::move(Result.ResponsePayload)), RawHash, RawSize);
+ CHECK(Compressed);
+ CHECK_EQ(RawHash, CompressedBlobsHashes[I]);
+ CHECK_EQ(RawSize, BlobsSizes[I]);
+ }
+ }
+}
+
+TEST_SUITE_END();
+
+} // namespace zen::tests
+#endif
diff --git a/src/zenserver-test/projectstore-tests.cpp b/src/zenserver-test/projectstore-tests.cpp
index ead062628..eb2e187d7 100644
--- a/src/zenserver-test/projectstore-tests.cpp
+++ b/src/zenserver-test/projectstore-tests.cpp
@@ -27,6 +27,8 @@ namespace zen::tests {
using namespace std::literals;
+TEST_SUITE_BEGIN("server.projectstore");
+
TEST_CASE("project.basic")
{
using namespace std::literals;
@@ -71,7 +73,7 @@ TEST_CASE("project.basic")
{
auto Response = Http.Get("/prj/test"sv);
- CHECK(Response.StatusCode == HttpResponseCode::OK);
+ REQUIRE(Response.StatusCode == HttpResponseCode::OK);
CbObject ResponseObject = Response.AsObject();
@@ -92,7 +94,7 @@ TEST_CASE("project.basic")
{
auto Response = Http.Get(""sv);
- CHECK(Response.StatusCode == HttpResponseCode::OK);
+ REQUIRE(Response.StatusCode == HttpResponseCode::OK);
CbObject ResponseObject = Response.AsObject();
@@ -213,7 +215,7 @@ TEST_CASE("project.basic")
auto Response = Http.Get(ChunkGetUri);
REQUIRE(Response);
- CHECK(Response.StatusCode == HttpResponseCode::OK);
+ REQUIRE(Response.StatusCode == HttpResponseCode::OK);
IoBuffer Data = Response.ResponsePayload;
IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath);
@@ -235,13 +237,13 @@ TEST_CASE("project.basic")
auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}});
REQUIRE(Response);
- CHECK(Response.StatusCode == HttpResponseCode::OK);
+ REQUIRE(Response.StatusCode == HttpResponseCode::OK);
IoBuffer Data = Response.ResponsePayload;
IoHash RawHash;
uint64_t RawSize;
CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize);
- CHECK(Compressed);
+ REQUIRE(Compressed);
IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer();
IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath);
CHECK(RawSize == ReferenceData.GetSize());
@@ -436,14 +438,14 @@ TEST_CASE("project.remote")
HttpClient Http{UrlBase};
HttpClient::Response Response = Http.Post(fmt::format("/prj/{}", ProjectName), ProjectPayload);
- CHECK(Response);
+ REQUIRE(Response);
};
auto MakeOplog = [](std::string_view UrlBase, std::string_view ProjectName, std::string_view OplogName) {
HttpClient Http{UrlBase};
HttpClient::Response Response =
Http.Post(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName), IoBuffer{}, ZenContentType::kCbObject);
- CHECK(Response);
+ REQUIRE(Response);
};
auto MakeOp = [](std::string_view UrlBase, std::string_view ProjectName, std::string_view OplogName, const CbPackage& OpPackage) {
@@ -454,7 +456,7 @@ TEST_CASE("project.remote")
HttpClient Http{UrlBase};
HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/new", ProjectName, OplogName), Body);
- CHECK(Response);
+ REQUIRE(Response);
};
MakeProject(Servers.GetInstance(0).GetBaseUri(), "proj0");
@@ -505,7 +507,7 @@ TEST_CASE("project.remote")
HttpClient::Response Response =
Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", Project, Oplog), Payload, {{"Accept", "application/x-ue-cbpkg"}});
- CHECK(Response);
+ REQUIRE(Response);
CbPackage ResponsePackage = ParsePackageMessage(Response.ResponsePayload);
CHECK(ResponsePackage.GetAttachments().size() == AttachmentHashes.size());
for (auto A : ResponsePackage.GetAttachments())
@@ -520,7 +522,7 @@ TEST_CASE("project.remote")
HttpClient Http{Servers.GetInstance(ServerIndex).GetBaseUri()};
HttpClient::Response Response = Http.Get(fmt::format("/prj/{}/oplog/{}/entries", Project, Oplog));
- CHECK(Response);
+ REQUIRE(Response);
IoBuffer Payload(Response.ResponsePayload);
CbObject OplogResonse = LoadCompactBinaryObject(Payload);
@@ -542,7 +544,7 @@ TEST_CASE("project.remote")
auto HttpWaitForCompletion = [](ZenServerInstance& Server, const HttpClient::Response& Response) {
REQUIRE(Response);
const uint64_t JobId = ParseInt<uint64_t>(Response.AsText()).value_or(0);
- CHECK(JobId != 0);
+ REQUIRE(JobId != 0);
HttpClient Http{Server.GetBaseUri()};
@@ -550,10 +552,10 @@ TEST_CASE("project.remote")
{
HttpClient::Response StatusResponse =
Http.Get(fmt::format("/admin/jobs/{}", JobId), {{"Accept", ToString(ZenContentType::kCbObject)}});
- CHECK(StatusResponse);
+ REQUIRE(StatusResponse);
CbObject ResponseObject = StatusResponse.AsObject();
std::string_view Status = ResponseObject["Status"sv].AsString();
- CHECK(Status != "Aborted"sv);
+ REQUIRE(Status != "Aborted"sv);
if (Status == "Complete"sv)
{
return;
@@ -888,17 +890,17 @@ TEST_CASE("project.rpcappendop")
Project.AddString("project"sv, ""sv);
Project.AddString("projectfile"sv, ""sv);
HttpClient::Response Response = Client.Post(fmt::format("/prj/{}", ProjectName), Project.Save());
- CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage(""));
+ REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage(""));
};
auto MakeOplog = [](HttpClient& Client, std::string_view ProjectName, std::string_view OplogName) {
HttpClient::Response Response =
Client.Post(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName), IoBuffer{}, ZenContentType::kCbObject);
- CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage(""));
+ REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage(""));
};
auto GetOplog = [](HttpClient& Client, std::string_view ProjectName, std::string_view OplogName) {
HttpClient::Response Response = Client.Get(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName));
- CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage(""));
+ REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage(""));
return Response.AsObject();
};
@@ -912,7 +914,7 @@ TEST_CASE("project.rpcappendop")
}
Request.EndArray(); // "ops"
HttpClient::Response Response = Client.Post(fmt::format("/prj/{}/oplog/{}/rpc", ProjectName, OplogName), Request.Save());
- CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage(""));
+ REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage(""));
CbObjectView ResponsePayload = Response.AsPackage().GetObject();
CbArrayView NeedArray = ResponsePayload["need"sv].AsArrayView();
@@ -1055,6 +1057,8 @@ TEST_CASE("project.rpcappendop")
}
}
+TEST_SUITE_END();
+
} // namespace zen::tests
#endif
diff --git a/src/zenserver-test/workspace-tests.cpp b/src/zenserver-test/workspace-tests.cpp
index 7595d790a..655f28872 100644
--- a/src/zenserver-test/workspace-tests.cpp
+++ b/src/zenserver-test/workspace-tests.cpp
@@ -73,6 +73,8 @@ GenerateFolderContent2(const std::filesystem::path& RootPath)
return Result;
}
+TEST_SUITE_BEGIN("server.workspace");
+
TEST_CASE("workspaces.create")
{
using namespace std::literals;
@@ -514,9 +516,9 @@ TEST_CASE("workspaces.share")
}
IoBuffer BatchResponse =
Client.Post(fmt::format("/ws/{}/{}/batch", WorkspaceId, ShareId), BuildChunkBatchRequest(BatchEntries)).ResponsePayload;
- CHECK(BatchResponse);
+ REQUIRE(BatchResponse);
std::vector<IoBuffer> BatchResult = ParseChunkBatchResponse(BatchResponse);
- CHECK(BatchResult.size() == Files.size());
+ REQUIRE(BatchResult.size() == Files.size());
for (const RequestChunkEntry& Request : BatchEntries)
{
IoBuffer Result = BatchResult[Request.CorrelationId];
@@ -537,5 +539,7 @@ TEST_CASE("workspaces.share")
CHECK(Client.Get(fmt::format("/ws/{}", WorkspaceId)).StatusCode == HttpResponseCode::NotFound);
}
+TEST_SUITE_END();
+
} // namespace zen::tests
#endif
diff --git a/src/zenserver-test/xmake.lua b/src/zenserver-test/xmake.lua
index 2a269cea1..7b208bbc7 100644
--- a/src/zenserver-test/xmake.lua
+++ b/src/zenserver-test/xmake.lua
@@ -6,10 +6,15 @@ target("zenserver-test")
add_headerfiles("**.h")
add_files("*.cpp")
add_files("zenserver-test.cpp", {unity_ignored = true })
- add_deps("zencore", "zenremotestore", "zenhttp")
+ add_deps("zencore", "zenremotestore", "zenhttp", "zencompute", "zenstore")
add_deps("zenserver", {inherit=false})
+ add_deps("zentest-appstub", {inherit=false})
add_packages("http_parser")
+ if has_config("zennomad") then
+ add_deps("zennomad")
+ end
+
if is_plat("macosx") then
add_ldflags("-framework CoreFoundation")
add_ldflags("-framework Security")
diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp
index 9a42bb73d..8d5400294 100644
--- a/src/zenserver-test/zenserver-test.cpp
+++ b/src/zenserver-test/zenserver-test.cpp
@@ -4,12 +4,12 @@
#if ZEN_WITH_TESTS
-# define ZEN_TEST_WITH_RUNNER 1
# include "zenserver-test.h"
# include <zencore/except.h>
# include <zencore/fmtutils.h>
# include <zencore/logging.h>
+# include <zencore/logging/registry.h>
# include <zencore/stream.h>
# include <zencore/string.h>
# include <zencore/testutils.h>
@@ -17,8 +17,8 @@
# include <zencore/timer.h>
# include <zenhttp/httpclient.h>
# include <zenhttp/packageformat.h>
-# include <zenutil/commandlineoptions.h>
-# include <zenutil/logging/testformatter.h>
+# include <zenutil/config/commandlineoptions.h>
+# include <zenutil/logging/fullformatter.h>
# include <zenutil/zenserverprocess.h>
# include <atomic>
@@ -86,8 +86,9 @@ main(int argc, char** argv)
zen::logging::InitializeLogging();
- zen::logging::SetLogLevel(zen::logging::level::Debug);
- spdlog::set_formatter(std::make_unique<zen::logging::full_test_formatter>("test", std::chrono::system_clock::now()));
+ zen::logging::SetLogLevel(zen::logging::Debug);
+ zen::logging::Registry::Instance().SetFormatter(
+ std::make_unique<zen::logging::FullFormatter>("test", std::chrono::system_clock::now()));
std::filesystem::path ProgramBaseDir = GetRunningExecutablePath().parent_path();
std::filesystem::path TestBaseDir = std::filesystem::current_path() / ".test";
@@ -97,6 +98,7 @@ main(int argc, char** argv)
// somehow in the future
std::string ServerClass;
+ bool Verbose = false;
for (int i = 1; i < argc; ++i)
{
@@ -107,13 +109,23 @@ main(int argc, char** argv)
ServerClass = argv[++i];
}
}
+ else if (argv[i] == "--verbose"sv)
+ {
+ Verbose = true;
+ }
}
zen::tests::TestEnv.InitializeForTest(ProgramBaseDir, TestBaseDir, ServerClass);
+ if (Verbose)
+ {
+ zen::tests::TestEnv.SetPassthroughOutput(true);
+ }
+
ZEN_INFO("Running tests...(base dir: '{}')", TestBaseDir);
zen::testing::TestRunner Runner;
+ Runner.SetDefaultSuiteFilter("server.*");
Runner.ApplyCommandLine(argc, argv);
return Runner.Run();
@@ -121,6 +133,8 @@ main(int argc, char** argv)
namespace zen::tests {
+TEST_SUITE_BEGIN("server.zenserver");
+
TEST_CASE("default.single")
{
std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
@@ -327,6 +341,8 @@ TEST_CASE("http.package")
CHECK_EQ(ResponsePackage, TestPackage);
}
+TEST_SUITE_END();
+
# if 0
TEST_CASE("lifetime.owner")
{
diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp
new file mode 100644
index 000000000..c64f081b3
--- /dev/null
+++ b/src/zenserver/compute/computeserver.cpp
@@ -0,0 +1,1021 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "computeserver.h"
+#include <zencompute/cloudmetadata.h>
+#include <zencompute/httpcomputeservice.h>
+#include <zencompute/httporchestrator.h>
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/fmtutils.h>
+# include <zencore/memory/llm.h>
+# include <zencore/memory/memorytrace.h>
+# include <zencore/memory/tagtrace.h>
+# include <zencore/scopeguard.h>
+# include <zencore/sentryintegration.h>
+# include <zencore/system.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/windows.h>
+# include <zenhttp/httpclient.h>
+# include <zenhttp/httpapiservice.h>
+# include <zenstore/cidstore.h>
+# include <zenutil/service.h>
+# if ZEN_WITH_HORDE
+# include <zenhorde/hordeconfig.h>
+# include <zenhorde/hordeprovisioner.h>
+# endif
+# if ZEN_WITH_NOMAD
+# include <zennomad/nomadconfig.h>
+# include <zennomad/nomadprovisioner.h>
+# endif
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <cxxopts.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+// Registers all compute-server command line options with cxxopts, grouped
+// into "compute", "horde" and "nomad" option categories. Values are bound
+// directly to fields of m_ServerOptions; enum-like string options
+// (horde-mode, horde-encryption, nomad-driver, nomad-distribution) are
+// captured into intermediate member strings and converted to their typed
+// equivalents later in ValidateOptions().
+void
+ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options)
+{
+    Options.add_option("compute",
+                       "",
+                       "max-actions",
+                       "Maximum number of concurrent local actions (0 = auto)",
+                       cxxopts::value<int32_t>(m_ServerOptions.MaxConcurrentActions)->default_value("0"),
+                       "");
+
+    Options.add_option("compute",
+                       "",
+                       "upstream-notification-endpoint",
+                       "Endpoint URL for upstream notifications",
+                       cxxopts::value<std::string>(m_ServerOptions.UpstreamNotificationEndpoint)->default_value(""),
+                       "");
+
+    Options.add_option("compute",
+                       "",
+                       "instance-id",
+                       "Instance ID for use in notifications",
+                       cxxopts::value<std::string>(m_ServerOptions.InstanceId)->default_value(""),
+                       "");
+
+    Options.add_option("compute",
+                       "",
+                       "coordinator-endpoint",
+                       "Endpoint URL for coordinator service",
+                       cxxopts::value<std::string>(m_ServerOptions.CoordinatorEndpoint)->default_value(""),
+                       "");
+
+    // "--idms" with no argument resolves to "auto" via implicit_value; an
+    // explicit argument selects a custom probe endpoint.
+    Options.add_option("compute",
+                       "",
+                       "idms",
+                       "Enable IDMS cloud detection; optionally specify a custom probe endpoint",
+                       cxxopts::value<std::string>(m_ServerOptions.IdmsEndpoint)->default_value("")->implicit_value("auto"),
+                       "");
+
+    Options.add_option("compute",
+                       "",
+                       "worker-websocket",
+                       "Use WebSocket for worker-orchestrator link (instant reachability detection)",
+                       cxxopts::value<bool>(m_ServerOptions.EnableWorkerWebSocket)->default_value("false"),
+                       "");
+
+# if ZEN_WITH_HORDE
+    // Horde provisioning options
+    Options.add_option("horde",
+                       "",
+                       "horde-enabled",
+                       "Enable Horde worker provisioning",
+                       cxxopts::value<bool>(m_ServerOptions.HordeConfig.Enabled)->default_value("false"),
+                       "");
+
+    Options.add_option("horde",
+                       "",
+                       "horde-server",
+                       "Horde server URL",
+                       cxxopts::value<std::string>(m_ServerOptions.HordeConfig.ServerUrl)->default_value(""),
+                       "");
+
+    Options.add_option("horde",
+                       "",
+                       "horde-token",
+                       "Horde authentication token",
+                       cxxopts::value<std::string>(m_ServerOptions.HordeConfig.AuthToken)->default_value(""),
+                       "");
+
+    Options.add_option("horde",
+                       "",
+                       "horde-pool",
+                       "Horde pool name",
+                       cxxopts::value<std::string>(m_ServerOptions.HordeConfig.Pool)->default_value(""),
+                       "");
+
+    Options.add_option("horde",
+                       "",
+                       "horde-cluster",
+                       "Horde cluster ID ('default' or '_auto' for auto-resolve)",
+                       cxxopts::value<std::string>(m_ServerOptions.HordeConfig.Cluster)->default_value("default"),
+                       "");
+
+    // Captured as a string; converted to the typed mode in ValidateOptions().
+    Options.add_option("horde",
+                       "",
+                       "horde-mode",
+                       "Horde connection mode (direct, tunnel, relay)",
+                       cxxopts::value<std::string>(m_HordeModeStr)->default_value("direct"),
+                       "");
+
+    Options.add_option("horde",
+                       "",
+                       "horde-encryption",
+                       "Horde transport encryption (none, aes)",
+                       cxxopts::value<std::string>(m_HordeEncryptionStr)->default_value("none"),
+                       "");
+
+    Options.add_option("horde",
+                       "",
+                       "horde-max-cores",
+                       "Maximum number of Horde cores to provision",
+                       cxxopts::value<int>(m_ServerOptions.HordeConfig.MaxCores)->default_value("2048"),
+                       "");
+
+    Options.add_option("horde",
+                       "",
+                       "horde-host",
+                       "Host address for Horde agents to connect back to",
+                       cxxopts::value<std::string>(m_ServerOptions.HordeConfig.HostAddress)->default_value(""),
+                       "");
+
+    Options.add_option("horde",
+                       "",
+                       "horde-condition",
+                       "Additional Horde agent filter condition",
+                       cxxopts::value<std::string>(m_ServerOptions.HordeConfig.Condition)->default_value(""),
+                       "");
+
+    Options.add_option("horde",
+                       "",
+                       "horde-binaries",
+                       "Path to directory containing zenserver binary for remote upload",
+                       cxxopts::value<std::string>(m_ServerOptions.HordeConfig.BinariesPath)->default_value(""),
+                       "");
+
+    Options.add_option("horde",
+                       "",
+                       "horde-zen-service-port",
+                       "Port number for Zen service communication",
+                       cxxopts::value<uint16_t>(m_ServerOptions.HordeConfig.ZenServicePort)->default_value("8558"),
+                       "");
+# endif
+
+# if ZEN_WITH_NOMAD
+    // Nomad provisioning options
+    Options.add_option("nomad",
+                       "",
+                       "nomad-enabled",
+                       "Enable Nomad worker provisioning",
+                       cxxopts::value<bool>(m_ServerOptions.NomadConfig.Enabled)->default_value("false"),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-server",
+                       "Nomad HTTP API URL",
+                       cxxopts::value<std::string>(m_ServerOptions.NomadConfig.ServerUrl)->default_value(""),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-token",
+                       "Nomad ACL token",
+                       cxxopts::value<std::string>(m_ServerOptions.NomadConfig.AclToken)->default_value(""),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-datacenter",
+                       "Nomad target datacenter",
+                       cxxopts::value<std::string>(m_ServerOptions.NomadConfig.Datacenter)->default_value("dc1"),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-namespace",
+                       "Nomad namespace",
+                       cxxopts::value<std::string>(m_ServerOptions.NomadConfig.Namespace)->default_value("default"),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-region",
+                       "Nomad region (empty for server default)",
+                       cxxopts::value<std::string>(m_ServerOptions.NomadConfig.Region)->default_value(""),
+                       "");
+
+    // Captured as strings; converted to typed values in ValidateOptions().
+    Options.add_option("nomad",
+                       "",
+                       "nomad-driver",
+                       "Nomad task driver (raw_exec, docker)",
+                       cxxopts::value<std::string>(m_NomadDriverStr)->default_value("raw_exec"),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-distribution",
+                       "Binary distribution mode (predeployed, artifact)",
+                       cxxopts::value<std::string>(m_NomadDistributionStr)->default_value("predeployed"),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-binary-path",
+                       "Path to zenserver on Nomad clients (predeployed mode)",
+                       cxxopts::value<std::string>(m_ServerOptions.NomadConfig.BinaryPath)->default_value(""),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-artifact-source",
+                       "URL to download zenserver binary (artifact mode)",
+                       cxxopts::value<std::string>(m_ServerOptions.NomadConfig.ArtifactSource)->default_value(""),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-docker-image",
+                       "Docker image for zenserver (docker driver)",
+                       cxxopts::value<std::string>(m_ServerOptions.NomadConfig.DockerImage)->default_value(""),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-max-jobs",
+                       "Maximum concurrent Nomad jobs",
+                       cxxopts::value<int>(m_ServerOptions.NomadConfig.MaxJobs)->default_value("64"),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-cpu-mhz",
+                       "CPU MHz allocated per Nomad task",
+                       cxxopts::value<int>(m_ServerOptions.NomadConfig.CpuMhz)->default_value("1000"),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-memory-mb",
+                       "Memory MB allocated per Nomad task",
+                       cxxopts::value<int>(m_ServerOptions.NomadConfig.MemoryMb)->default_value("2048"),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-cores-per-job",
+                       "Estimated cores per Nomad job (for scaling)",
+                       cxxopts::value<int>(m_ServerOptions.NomadConfig.CoresPerJob)->default_value("32"),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-max-cores",
+                       "Maximum total cores to provision via Nomad",
+                       cxxopts::value<int>(m_ServerOptions.NomadConfig.MaxCores)->default_value("2048"),
+                       "");
+
+    Options.add_option("nomad",
+                       "",
+                       "nomad-job-prefix",
+                       "Prefix for generated Nomad job IDs",
+                       cxxopts::value<std::string>(m_ServerOptions.NomadConfig.JobPrefix)->default_value("zenserver-worker"),
+                       "");
+# endif
+}
+
+// No compute-specific Lua config-file options yet; common options come from
+// the base configurator.
+void
+ZenComputeServerConfigurator::AddConfigOptions(LuaConfig::Options& Options)
+{
+    ZEN_UNUSED(Options);
+}
+
+// No compute-specific post-parse application of CLI options is required;
+// all CLI values are bound directly to config fields in AddCliOptions().
+void
+ZenComputeServerConfigurator::ApplyOptions(cxxopts::Options& Options)
+{
+    ZEN_UNUSED(Options);
+}
+
+// No compute-specific reaction to config-file parsing.
+void
+ZenComputeServerConfigurator::OnConfigFileParsed(LuaConfig::Options& LuaOptions)
+{
+    ZEN_UNUSED(LuaOptions);
+}
+
+// Converts the string-typed options captured in AddCliOptions() into their
+// strongly-typed config equivalents.
+// NOTE(review): FromString results are not checked here — presumably an
+// unrecognized value leaves/filles in a default; confirm against the
+// horde::FromString / nomad::FromString implementations.
+void
+ZenComputeServerConfigurator::ValidateOptions()
+{
+# if ZEN_WITH_HORDE
+    horde::FromString(m_ServerOptions.HordeConfig.Mode, m_HordeModeStr);
+    horde::FromString(m_ServerOptions.HordeConfig.EncryptionMode, m_HordeEncryptionStr);
+# endif
+
+# if ZEN_WITH_NOMAD
+    nomad::FromString(m_ServerOptions.NomadConfig.TaskDriver, m_NomadDriverStr);
+    nomad::FromString(m_ServerOptions.NomadConfig.BinDistribution, m_NomadDistributionStr);
+# endif
+}
+
+///////////////////////////////////////////////////////////////////////////
+
+ZenComputeServer::ZenComputeServer()
+{
+}
+
+// Runs a full Cleanup() so owned services, provisioners and the io_context
+// runner thread are torn down even if the caller never invoked Cleanup()
+// explicitly. NOTE(review): Cleanup() appears idempotent (it guards on null
+// members) — confirm before relying on double invocation.
+ZenComputeServer::~ZenComputeServer()
+{
+    Cleanup();
+}
+
+// Initializes the compute server: base server setup (ports, HTTP), state,
+// services and service registration.
+//
+// Returns the effective base port on success, or a negative value on
+// failure (propagated from ZenServerBase::Initialize, which logs the cause).
+int
+ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServerState::ZenServerEntry* ServerEntry)
+{
+    ZEN_TRACE_CPU("ZenComputeServer::Initialize");
+    ZEN_MEMSCOPE(GetZenserverTag());
+
+    ZEN_INFO(ZEN_APP_NAME " initializing in COMPUTE server mode");
+
+    const int EffectiveBasePort = ZenServerBase::Initialize(ServerConfig, ServerEntry);
+    if (EffectiveBasePort < 0)
+    {
+        // Base initialization failed; it has already logged why.
+        return EffectiveBasePort;
+    }
+
+    m_CoordinatorEndpoint = ServerConfig.CoordinatorEndpoint;
+    m_InstanceId = ServerConfig.InstanceId;
+    m_EnableWorkerWebSocket = ServerConfig.EnableWorkerWebSocket;
+
+    // This is a workaround to make sure we can have automated tests. Without
+    // this the ranges for different child zen compute processes could overlap with
+    // the main test range.
+    ZenServerEnvironment::SetBaseChildId(2000);
+
+    // Debug-only option: trigger a deliberate break once Run() starts.
+    m_DebugOptionForcedCrash = ServerConfig.ShouldCrash;
+
+    InitializeState(ServerConfig);
+    InitializeServices(ServerConfig);
+    RegisterServices(ServerConfig);
+
+    ZenServerBase::Finalize();
+
+    return EffectiveBasePort;
+}
+
+// Tears down the compute server. The shutdown ORDER here is deliberate:
+// timers first (so nothing re-enqueues), then provisioners (which own
+// threads referencing HTTP services), then the WebSocket link, then
+// services, then the io_context runner, and finally the HTTP server.
+// All exceptions are contained so Cleanup() is safe from the destructor.
+void
+ZenComputeServer::Cleanup()
+{
+    ZEN_TRACE_CPU("ZenComputeServer::Cleanup");
+    ZEN_INFO(ZEN_APP_NAME " cleaning up");
+    try
+    {
+        // Cancel the maintenance timer so it stops re-enqueuing before we
+        // tear down the provisioners it references.
+        m_ProvisionerMaintenanceTimer.cancel();
+        m_AnnounceTimer.cancel();
+
+# if ZEN_WITH_HORDE
+        // Shut down Horde provisioner first — this signals all agent threads
+        // to exit and joins them before we tear down HTTP services.
+        m_HordeProvisioner.reset();
+# endif
+
+# if ZEN_WITH_NOMAD
+        // Shut down Nomad provisioner — stops the management thread and
+        // sends stop requests for all tracked jobs.
+        m_NomadProvisioner.reset();
+# endif
+
+        // Close the orchestrator WebSocket client before stopping the io_context
+        m_WsReconnectTimer.cancel();
+        if (m_OrchestratorWsClient)
+        {
+            m_OrchestratorWsClient->Close();
+            m_OrchestratorWsClient.reset();
+        }
+        m_OrchestratorWsHandler.reset();
+
+        // Join the async cloud-metadata probe (if still pending) so its
+        // lambda cannot outlive this object, then drop the result.
+        ResolveCloudMetadata();
+        m_CloudMetadata.reset();
+
+        // Shut down services that own threads or use the io_context before we
+        // stop the io_context and close the HTTP server.
+        if (m_OrchestratorService)
+        {
+            m_OrchestratorService->Shutdown();
+        }
+        if (m_ComputeService)
+        {
+            m_ComputeService->Shutdown();
+        }
+
+        m_IoContext.stop();
+        if (m_IoRunner.joinable())
+        {
+            m_IoRunner.join();
+        }
+
+        ShutdownServices();
+
+        if (m_Http)
+        {
+            m_Http->Close();
+        }
+    }
+    catch (const std::exception& Ex)
+    {
+        ZEN_ERROR("exception thrown during Cleanup() in {}: '{}'", ZEN_APP_NAME, Ex.what());
+    }
+}
+
+// Placeholder — the compute server currently has no persistent state to
+// initialize beyond what InitializeServices() sets up.
+void
+ZenComputeServer::InitializeState(const ZenComputeServerConfig& ServerConfig)
+{
+    ZEN_UNUSED(ServerConfig);
+}
+
+// Instantiates all compute-server services: the local CAS, the optional
+// async cloud-metadata probe, the HTTP API / orchestrator / function /
+// frontend services, and (when enabled) the Nomad and Horde provisioners.
+// The previously duplicated orchestrator-endpoint construction is factored
+// into a single local helper shared by both provisioner branches.
+void
+ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig)
+{
+    ZEN_TRACE_CPU("ZenComputeServer::InitializeServices");
+    ZEN_INFO("initializing compute services");
+
+    // Local CAS used by the function service for action payloads.
+    CidStoreConfiguration Config;
+    Config.RootDirectory = m_DataRoot / "cas";
+
+    m_CidStore = std::make_unique<CidStore>(m_GcManager);
+    m_CidStore->Initialize(Config);
+
+    // Kick off cloud-environment detection in the background; the future is
+    // consumed later via ResolveCloudMetadata() when the data is needed.
+    if (!ServerConfig.IdmsEndpoint.empty())
+    {
+        ZEN_INFO("detecting cloud environment (async)");
+        if (ServerConfig.IdmsEndpoint == "auto")
+        {
+            m_CloudMetadataFuture = std::async(std::launch::async, [DataDir = ServerConfig.DataDir] {
+                return std::make_unique<zen::compute::CloudMetadata>(DataDir / "cloud");
+            });
+        }
+        else
+        {
+            ZEN_INFO("using custom IDMS endpoint: {}", ServerConfig.IdmsEndpoint);
+            m_CloudMetadataFuture = std::async(std::launch::async, [DataDir = ServerConfig.DataDir, Endpoint = ServerConfig.IdmsEndpoint] {
+                return std::make_unique<zen::compute::CloudMetadata>(DataDir / "cloud", Endpoint);
+            });
+        }
+    }
+
+    ZEN_INFO("instantiating API service");
+    m_ApiService = std::make_unique<zen::HttpApiService>(*m_Http);
+
+    ZEN_INFO("instantiating orchestrator service");
+    m_OrchestratorService =
+        std::make_unique<zen::compute::HttpOrchestratorService>(ServerConfig.DataDir / "orch", ServerConfig.EnableWorkerWebSocket);
+
+    ZEN_INFO("instantiating function service");
+    m_ComputeService = std::make_unique<zen::compute::HttpComputeService>(*m_CidStore,
+                                                                          m_StatsService,
+                                                                          ServerConfig.DataDir / "functions",
+                                                                          ServerConfig.MaxConcurrentActions);
+
+    m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatusService);
+
+# if ZEN_WITH_HORDE || ZEN_WITH_NOMAD
+    // Builds the orchestrator callback endpoint (with guaranteed trailing
+    // slash) that provisioned workers use to reach back to this server.
+    // Shared by all provisioner branches below.
+    auto MakeOrchestratorEndpoint = [this](ExtendableStringBuilder<256>& OrchestratorEndpoint) {
+        OrchestratorEndpoint << m_Http->GetServiceUri(m_OrchestratorService.get());
+        if (auto View = OrchestratorEndpoint.ToView(); !View.empty() && View.back() != '/')
+        {
+            OrchestratorEndpoint << '/';
+        }
+    };
+# endif
+
+# if ZEN_WITH_NOMAD
+    // Nomad provisioner
+    if (ServerConfig.NomadConfig.Enabled && !ServerConfig.NomadConfig.ServerUrl.empty())
+    {
+        ZEN_INFO("instantiating Nomad provisioner (server: {})", ServerConfig.NomadConfig.ServerUrl);
+
+        const auto& NomadCfg = ServerConfig.NomadConfig;
+
+        if (!NomadCfg.Validate())
+        {
+            ZEN_ERROR("invalid Nomad configuration");
+        }
+        else
+        {
+            ExtendableStringBuilder<256> OrchestratorEndpoint;
+            MakeOrchestratorEndpoint(OrchestratorEndpoint);
+
+            m_NomadProvisioner = std::make_unique<nomad::NomadProvisioner>(NomadCfg, OrchestratorEndpoint);
+        }
+    }
+# endif
+
+# if ZEN_WITH_HORDE
+    // Horde provisioner
+    if (ServerConfig.HordeConfig.Enabled && !ServerConfig.HordeConfig.ServerUrl.empty())
+    {
+        ZEN_INFO("instantiating Horde provisioner (server: {})", ServerConfig.HordeConfig.ServerUrl);
+
+        const auto& HordeConfig = ServerConfig.HordeConfig;
+
+        if (!HordeConfig.Validate())
+        {
+            ZEN_ERROR("invalid Horde configuration");
+        }
+        else
+        {
+            ExtendableStringBuilder<256> OrchestratorEndpoint;
+            MakeOrchestratorEndpoint(OrchestratorEndpoint);
+
+            // If no binaries path is specified, just use the running executable's directory
+            std::filesystem::path BinariesPath = HordeConfig.BinariesPath.empty() ? GetRunningExecutablePath().parent_path()
+                                                                                  : std::filesystem::path(HordeConfig.BinariesPath);
+            std::filesystem::path WorkingDir = ServerConfig.DataDir / "horde";
+
+            m_HordeProvisioner = std::make_unique<horde::HordeProvisioner>(HordeConfig, BinariesPath, WorkingDir, OrchestratorEndpoint);
+        }
+    }
+# endif
+}
+
+// Completes the asynchronous cloud-metadata probe, if one was started.
+// The first successful call consumes the future; afterwards the cached
+// m_CloudMetadata (possibly null) is all that remains.
+void
+ZenComputeServer::ResolveCloudMetadata()
+{
+    if (!m_CloudMetadataFuture.valid())
+    {
+        // Nothing pending — either no probe was started or it has already
+        // been resolved.
+        return;
+    }
+    m_CloudMetadata = m_CloudMetadataFuture.get();
+}
+
+// Returns the instance identifier used in announce payloads: the
+// operator-supplied ID when present, otherwise a "<machine>-<pid>"
+// fallback synthesized from the local environment.
+std::string
+ZenComputeServer::GetInstanceId() const
+{
+    return m_InstanceId.empty() ? fmt::format("{}-{}", GetMachineName(), GetCurrentProcessId()) : m_InstanceId;
+}
+
+// Returns the URI other nodes should use to reach this server.
+// NOTE(review): nullptr presumably selects the server's root/base URI
+// rather than a specific service's — confirm against GetServiceUri().
+std::string
+ZenComputeServer::GetAnnounceUrl() const
+{
+    return m_Http->GetServiceUri(nullptr);
+}
+
+// Attaches every instantiated service to the HTTP server, preserving the
+// registration order: api, orchestrator, compute, frontend.
+void
+ZenComputeServer::RegisterServices(const ZenComputeServerConfig& ServerConfig)
+{
+    ZEN_TRACE_CPU("ZenComputeServer::RegisterServices");
+    ZEN_UNUSED(ServerConfig);
+
+    // Each member is a unique_ptr that may be null if its service was not
+    // instantiated; skip nulls.
+    auto RegisterIfPresent = [this](auto& Service) {
+        if (Service)
+        {
+            m_Http->RegisterService(*Service);
+        }
+    };
+
+    RegisterIfPresent(m_ApiService);
+    RegisterIfPresent(m_OrchestratorService);
+    RegisterIfPresent(m_ComputeService);
+    RegisterIfPresent(m_FrontendService);
+}
+
+// Serializes the worker's announce payload: identity, host metrics, HTTP
+// traffic counters, action-queue state and (when available) cloud metadata.
+CbObject
+ZenComputeServer::BuildAnnounceBody()
+{
+    CbObjectWriter AnnounceBody;
+    AnnounceBody << "id" << GetInstanceId();
+    AnnounceBody << "uri" << GetAnnounceUrl();
+    AnnounceBody << "hostname" << GetMachineName();
+    AnnounceBody << "platform" << GetRuntimePlatformName();
+
+    ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query());
+
+    AnnounceBody.BeginObject("metrics");
+    Describe(Sm, AnnounceBody);
+    AnnounceBody.EndObject();
+
+    AnnounceBody << "cpu_usage" << Sm.CpuUsagePercent;
+    // Widen to 64-bit BEFORE converting MiB to bytes: 2 GiB already exceeds
+    // INT32_MAX bytes, so a narrower metrics field would overflow here.
+    AnnounceBody << "memory_total" << uint64_t(Sm.SystemMemoryMiB) * 1024 * 1024;
+    AnnounceBody << "memory_used" << uint64_t(Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024;
+
+    AnnounceBody << "bytes_received" << m_Http->GetTotalBytesReceived();
+    AnnounceBody << "bytes_sent" << m_Http->GetTotalBytesSent();
+
+    auto Actions = m_ComputeService->GetActionCounts();
+    AnnounceBody << "actions_pending" << Actions.Pending;
+    AnnounceBody << "actions_running" << Actions.Running;
+    AnnounceBody << "actions_completed" << Actions.Completed;
+    AnnounceBody << "active_queues" << Actions.ActiveQueues;
+
+    // Derive provisioner from instance ID prefix (e.g. "horde-xxx" or "nomad-xxx")
+    if (m_InstanceId.starts_with("horde-"))
+    {
+        AnnounceBody << "provisioner"
+                     << "horde";
+    }
+    else if (m_InstanceId.starts_with("nomad-"))
+    {
+        AnnounceBody << "provisioner"
+                     << "nomad";
+    }
+
+    // Join the async cloud probe if still pending, then include its data.
+    ResolveCloudMetadata();
+    if (m_CloudMetadata)
+    {
+        m_CloudMetadata->Describe(AnnounceBody);
+    }
+
+    return AnnounceBody.Save();
+}
+
+// Sends one availability announcement to the coordinator. Prefers an open
+// WebSocket link; otherwise falls back to an HTTP POST to "announce".
+// No-op when no coordinator endpoint is configured or the compute service
+// was never instantiated. All failures are logged, never thrown.
+void
+ZenComputeServer::PostAnnounce()
+{
+    ZEN_TRACE_CPU("ZenComputeServer::PostAnnounce");
+
+    if (!m_ComputeService || m_CoordinatorEndpoint.empty())
+    {
+        return;
+    }
+
+    ZEN_INFO("notifying coordinator at '{}' of our availability at '{}'", m_CoordinatorEndpoint, GetAnnounceUrl());
+
+    try
+    {
+        CbObject Body = BuildAnnounceBody();
+
+        // If we have an active WebSocket connection, send via that instead of HTTP POST
+        if (m_OrchestratorWsClient && m_OrchestratorWsClient->IsOpen())
+        {
+            MemoryView View = Body.GetView();
+            m_OrchestratorWsClient->SendBinary(std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(View.GetData()), View.GetSize()));
+            ZEN_INFO("announced to coordinator via WebSocket");
+            return;
+        }
+
+        HttpClient CoordinatorHttp(m_CoordinatorEndpoint);
+        HttpClient::Response Result = CoordinatorHttp.Post("announce", std::move(Body));
+
+        // Distinguish transport errors from HTTP-level failures for logging.
+        if (Result.Error)
+        {
+            ZEN_ERROR("failed to notify coordinator at '{}': HTTP error {} - {}",
+                      m_CoordinatorEndpoint,
+                      Result.Error->ErrorCode,
+                      Result.Error->ErrorMessage);
+        }
+        else if (!IsHttpOk(Result.StatusCode))
+        {
+            ZEN_ERROR("failed to notify coordinator at '{}': unexpected HTTP status code {}",
+                      m_CoordinatorEndpoint,
+                      static_cast<int>(Result.StatusCode));
+        }
+        else
+        {
+            ZEN_INFO("successfully notified coordinator at '{}'", m_CoordinatorEndpoint);
+        }
+    }
+    catch (const std::exception& Ex)
+    {
+        ZEN_ERROR("failed to notify coordinator at '{}': {}", m_CoordinatorEndpoint, Ex.what());
+    }
+}
+
+// Schedules the next periodic announce 15 seconds out; the completion
+// handler re-arms itself, forming the announce loop. Cancellation (during
+// Cleanup) arrives as a non-zero error code and ends the chain.
+void
+ZenComputeServer::EnqueueAnnounceTimer()
+{
+    if (!m_ComputeService || m_CoordinatorEndpoint.empty())
+    {
+        return;
+    }
+
+    m_AnnounceTimer.expires_after(std::chrono::seconds(15));
+    m_AnnounceTimer.async_wait([this](const asio::error_code& Ec) {
+        if (!Ec)
+        {
+            PostAnnounce();
+            EnqueueAnnounceTimer();
+        }
+    });
+    // Make sure a thread is servicing the io_context for the timer.
+    EnsureIoRunner();
+}
+
+// Opens the persistent worker→orchestrator WebSocket link when enabled.
+// The coordinator's HTTP(S) endpoint is rewritten to the matching ws(s)
+// scheme and suffixed with the orchestrator's "orch/ws" route.
+// NOTE(review): an endpoint that already carries a ws:// or other scheme
+// is used verbatim — confirm that is the intended behavior.
+void
+ZenComputeServer::InitializeOrchestratorWebSocket()
+{
+    if (!m_EnableWorkerWebSocket || m_CoordinatorEndpoint.empty())
+    {
+        return;
+    }
+
+    // Convert http://host:port → ws://host:port/orch/ws
+    std::string WsUrl = m_CoordinatorEndpoint;
+    if (WsUrl.starts_with("http://"))
+    {
+        WsUrl = "ws://" + WsUrl.substr(7);
+    }
+    else if (WsUrl.starts_with("https://"))
+    {
+        WsUrl = "wss://" + WsUrl.substr(8);
+    }
+    if (!WsUrl.empty() && WsUrl.back() != '/')
+    {
+        WsUrl += '/';
+    }
+    WsUrl += "orch/ws";
+
+    ZEN_INFO("establishing WebSocket link to orchestrator at {}", WsUrl);
+
+    // Handler must outlive the client, which holds a reference to it.
+    m_OrchestratorWsHandler = std::make_unique<OrchestratorWsHandler>(*this);
+    m_OrchestratorWsClient =
+        std::make_unique<HttpWsClient>(WsUrl, *m_OrchestratorWsHandler, m_IoContext, HttpWsClientSettings{.LogCategory = "orch_ws"});
+
+    m_OrchestratorWsClient->Connect();
+    EnsureIoRunner();
+}
+
+// Schedules a single WebSocket reconnect attempt 5 seconds out. Further
+// retries are driven by OnWsClose() re-invoking this after each failure.
+void
+ZenComputeServer::EnqueueWsReconnect()
+{
+    m_WsReconnectTimer.expires_after(std::chrono::seconds(5));
+    m_WsReconnectTimer.async_wait([this](const asio::error_code& Ec) {
+        if (!Ec && m_OrchestratorWsClient)
+        {
+            ZEN_INFO("attempting WebSocket reconnect to orchestrator");
+            m_OrchestratorWsClient->Connect();
+        }
+    });
+    EnsureIoRunner();
+}
+
+// Connection established: push a fresh announce immediately so the
+// orchestrator learns about this worker without waiting for the timer.
+void
+ZenComputeServer::OrchestratorWsHandler::OnWsOpen()
+{
+    ZEN_INFO("WebSocket link to orchestrator established");
+
+    // Send initial announce immediately over the WebSocket
+    Server.PostAnnounce();
+}
+
+void
+ZenComputeServer::OrchestratorWsHandler::OnWsMessage([[maybe_unused]] const WebSocketMessage& Msg)
+{
+    // Orchestrator does not push messages to workers; ignore
+}
+
+// Connection lost: fall back to HTTP announces and keep retrying the
+// WebSocket via the reconnect timer.
+void
+ZenComputeServer::OrchestratorWsHandler::OnWsClose([[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason)
+{
+    ZEN_WARN("WebSocket link to orchestrator closed (code {}), falling back to HTTP announce", Code);
+
+    // Trigger an immediate HTTP announce so the orchestrator has fresh state,
+    // then schedule a reconnect attempt.
+    Server.PostAnnounce();
+    Server.EnqueueWsReconnect();
+}
+
+// Periodic provisioner upkeep: re-asserts the (unbounded) target core count
+// on each provisioner — UINT32_MAX is clamped internally to the configured
+// MaxCores — and logs current provisioning statistics at debug level.
+void
+ZenComputeServer::ProvisionerMaintenanceTick()
+{
+# if ZEN_WITH_HORDE
+    if (m_HordeProvisioner)
+    {
+        m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX);
+        auto Stats = m_HordeProvisioner->GetStats();
+        ZEN_DEBUG("Horde maintenance: target={}, estimated={}, active={}",
+                  Stats.TargetCoreCount,
+                  Stats.EstimatedCoreCount,
+                  Stats.ActiveCoreCount);
+    }
+# endif
+
+# if ZEN_WITH_NOMAD
+    if (m_NomadProvisioner)
+    {
+        m_NomadProvisioner->SetTargetCoreCount(UINT32_MAX);
+        auto Stats = m_NomadProvisioner->GetStats();
+        ZEN_DEBUG("Nomad maintenance: target={}, estimated={}, running jobs={}",
+                  Stats.TargetCoreCount,
+                  Stats.EstimatedCoreCount,
+                  Stats.RunningJobCount);
+    }
+# endif
+}
+
+// Arms the 15-second provisioner maintenance timer; the completion handler
+// ticks and re-arms itself. Skipped entirely when no provisioner exists,
+// so the timer chain never starts on plain worker instances.
+void
+ZenComputeServer::EnqueueProvisionerMaintenanceTimer()
+{
+    bool HasProvisioner = false;
+# if ZEN_WITH_HORDE
+    if (m_HordeProvisioner != nullptr)
+    {
+        HasProvisioner = true;
+    }
+# endif
+# if ZEN_WITH_NOMAD
+    if (m_NomadProvisioner != nullptr)
+    {
+        HasProvisioner = true;
+    }
+# endif
+
+    if (!HasProvisioner)
+    {
+        return;
+    }
+
+    m_ProvisionerMaintenanceTimer.expires_after(std::chrono::seconds(15));
+    m_ProvisionerMaintenanceTimer.async_wait([this](const asio::error_code& Ec) {
+        if (!Ec)
+        {
+            ProvisionerMaintenanceTick();
+            EnqueueProvisionerMaintenanceTimer();
+        }
+    });
+    EnsureIoRunner();
+}
+
+// Main run loop: prints startup banners/diagnostics, transitions to the
+// running state, kicks off the announce and provisioning machinery, then
+// blocks in the HTTP server until shutdown is requested.
+void
+ZenComputeServer::Run()
+{
+    ZEN_TRACE_CPU("ZenComputeServer::Run");
+
+    if (m_ProcessMonitor.IsActive())
+    {
+        CheckOwnerPid();
+    }
+
+    if (!m_TestMode)
+    {
+        // clang-format off
+        ZEN_INFO( R"(__________ _________ __ )" "\n"
+                  R"(\____ /____ ____ \_ ___ \ ____ _____ ______ __ ___/ |_ ____ )" "\n"
+                  R"( / // __ \ / \/ \ \/ / _ \ / \\____ \| | \ __\/ __ \ )" "\n"
+                  R"( / /\ ___/| | \ \___( <_> ) Y Y \ |_> > | /| | \ ___/ )" "\n"
+                  R"(/_______ \___ >___| /\______ /\____/|__|_| / __/|____/ |__| \___ >)" "\n"
+                  R"( \/ \/ \/ \/ \/|__| \/ )");
+        // clang-format on
+
+        ExtendableStringBuilder<256> BuildOptions;
+        GetBuildOptions(BuildOptions, '\n');
+        ZEN_INFO("Build options ({}/{}):\n{}", GetOperatingSystemName(), GetCpuName(), BuildOptions);
+    }
+
+    ZEN_INFO(ZEN_APP_NAME " now running as COMPUTE (pid: {})", GetCurrentProcessId());
+
+# if ZEN_PLATFORM_WINDOWS
+    if (zen::windows::IsRunningOnWine())
+    {
+        ZEN_INFO("detected Wine session - " ZEN_APP_NAME " is not formally tested on Wine and may therefore not work or perform well");
+    }
+# endif
+
+# if ZEN_USE_SENTRY
+    ZEN_INFO("sentry crash handler {}", m_UseSentry ? "ENABLED" : "DISABLED");
+    if (m_UseSentry)
+    {
+        SentryIntegration::ClearCaches();
+    }
+# endif
+
+    // Debug-only option (see Initialize) to exercise crash handling.
+    if (m_DebugOptionForcedCrash)
+    {
+        ZEN_DEBUG_BREAK();
+    }
+
+    const bool IsInteractiveMode = IsInteractiveSession(); // &&!m_TestMode;
+
+    SetNewState(kRunning);
+
+    OnReady();
+
+    // Announce immediately, then keep announcing on the timer; optionally
+    // bring up the WebSocket link for instant reachability detection.
+    PostAnnounce();
+    EnqueueAnnounceTimer();
+    InitializeOrchestratorWebSocket();
+
+# if ZEN_WITH_HORDE
+    // Start Horde provisioning if configured — request maximum allowed cores.
+    // SetTargetCoreCount clamps to HordeConfig::MaxCores internally.
+    if (m_HordeProvisioner)
+    {
+        ZEN_INFO("Horde provisioning starting");
+        m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX);
+        auto Stats = m_HordeProvisioner->GetStats();
+        ZEN_INFO("Horde provisioning started (target cores: {})", Stats.TargetCoreCount);
+    }
+# endif
+
+# if ZEN_WITH_NOMAD
+    // Start Nomad provisioning if configured — request maximum allowed cores.
+    // SetTargetCoreCount clamps to NomadConfig::MaxCores internally.
+    if (m_NomadProvisioner)
+    {
+        m_NomadProvisioner->SetTargetCoreCount(UINT32_MAX);
+        auto Stats = m_NomadProvisioner->GetStats();
+        ZEN_INFO("Nomad provisioning started (target cores: {})", Stats.TargetCoreCount);
+    }
+# endif
+
+    EnqueueProvisionerMaintenanceTimer();
+
+    // Blocks until the HTTP server is asked to exit.
+    m_Http->Run(IsInteractiveMode);
+
+    SetNewState(kShuttingDown);
+
+    ZEN_INFO(ZEN_APP_NAME " exiting");
+}
+
+//////////////////////////////////////////////////////////////////////////////////
+
+// Stores a reference to the typed compute config alongside the base-class
+// view of the same object.
+ZenComputeServerMain::ZenComputeServerMain(ZenComputeServerConfig& ServerOptions)
+: ZenServerMain(ServerOptions)
+, m_ServerOptions(ServerOptions)
+{
+}
+
+// Constructs, initializes and runs a ZenComputeServer: wires up the named
+// shutdown event, the shutdown monitor thread and the readiness callback
+// used to signal the parent process / service manager.
+void
+ZenComputeServerMain::DoRun(ZenServerState::ZenServerEntry* Entry)
+{
+    ZEN_TRACE_CPU("ZenComputeServerMain::DoRun");
+
+    ZenComputeServer Server;
+    Server.SetDataRoot(m_ServerOptions.DataDir);
+    Server.SetContentRoot(m_ServerOptions.ContentDir);
+    Server.SetTestMode(m_ServerOptions.IsTest);
+    Server.SetDedicatedMode(m_ServerOptions.IsDedicated);
+
+    const int EffectiveBasePort = Server.Initialize(m_ServerOptions, Entry);
+    if (EffectiveBasePort < 0)
+    {
+        // Initialize() reports failure as ANY negative value (matching its
+        // own `< 0` contract) and has already logged what the issue is -
+        // just exit with failure code here.
+        std::exit(1);
+    }
+
+    Entry->EffectiveListenPort = uint16_t(EffectiveBasePort);
+    if (EffectiveBasePort != m_ServerOptions.BasePort)
+    {
+        ZEN_INFO(ZEN_APP_NAME " - relocated to base port {}", EffectiveBasePort);
+        m_ServerOptions.BasePort = EffectiveBasePort;
+    }
+
+    std::unique_ptr<std::thread> ShutdownThread;
+    std::unique_ptr<NamedEvent> ShutdownEvent;
+
+    // The event name encodes the (possibly relocated) base port so external
+    // tooling can signal this specific instance.
+    ExtendableStringBuilder<64> ShutdownEventName;
+    ShutdownEventName << "Zen_" << m_ServerOptions.BasePort << "_Shutdown";
+    ShutdownEvent = std::make_unique<NamedEvent>(ShutdownEventName);
+
+    // Monitor shutdown signals
+
+    ShutdownThread = std::make_unique<std::thread>([&] {
+        SetCurrentThreadName("shutdown_mon");
+
+        ZEN_INFO("shutdown monitor thread waiting for shutdown signal '{}' for process {}", ShutdownEventName, zen::GetCurrentProcessId());
+
+        if (ShutdownEvent->Wait())
+        {
+            ZEN_INFO("shutdown signal for pid {} received", zen::GetCurrentProcessId());
+            Server.RequestExit(0);
+        }
+        else
+        {
+            ZEN_INFO("shutdown signal wait() failed");
+        }
+    });
+
+    // Scope guard: on any exit path, wake the monitor thread (by setting the
+    // event it waits on) and join it before the locals it captures die.
+    auto CleanupShutdown = MakeGuard([&ShutdownEvent, &ShutdownThread] {
+        ReportServiceStatus(ServiceStatus::Stopping);
+
+        if (ShutdownEvent)
+        {
+            ShutdownEvent->Set();
+        }
+        if (ShutdownThread && ShutdownThread->joinable())
+        {
+            ShutdownThread->join();
+        }
+    });
+
+    // If we have a parent process, establish the mechanisms we need
+    // to be able to communicate readiness with the parent
+
+    Server.SetIsReadyFunc([&] {
+        std::error_code Ec;
+        m_LockFile.Update(MakeLockData(true), Ec);
+        ReportServiceStatus(ServiceStatus::Running);
+        NotifyReady();
+    });
+
+    Server.Run();
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h
new file mode 100644
index 000000000..8f4edc0f0
--- /dev/null
+++ b/src/zenserver/compute/computeserver.h
@@ -0,0 +1,188 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenserver.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <future>
+# include <zencore/system.h>
+# include <zenhttp/httpwsclient.h>
+# include <zenstore/gc.h>
+# include "frontend/frontend.h"
+
+namespace cxxopts {
+class Options;
+}
+namespace zen::LuaConfig {
+struct Options;
+}
+
+namespace zen::compute {
+class CloudMetadata;
+class HttpComputeService;
+class HttpOrchestratorService;
+} // namespace zen::compute
+
+# if ZEN_WITH_HORDE
+# include <zenhorde/hordeconfig.h>
+namespace zen::horde {
+class HordeProvisioner;
+} // namespace zen::horde
+# endif
+
+# if ZEN_WITH_NOMAD
+# include <zennomad/nomadconfig.h>
+namespace zen::nomad {
+class NomadProvisioner;
+} // namespace zen::nomad
+# endif
+
+namespace zen {
+
+class CidStore;
+class HttpApiService;
+
+/** Configuration for a compute-mode server instance, extending the common
+ * server config with coordinator/announce settings and the optional
+ * Horde/Nomad provisioner configurations. Populated by
+ * ZenComputeServerConfigurator from CLI/config-file options. */
+struct ZenComputeServerConfig : public ZenServerConfig
+{
+    std::string UpstreamNotificationEndpoint;
+    std::string InstanceId;             // For use in notifications
+    std::string CoordinatorEndpoint;    // Announce target; empty disables announcing
+    std::string IdmsEndpoint;           // "" = off, "auto" = default probe, else custom probe URL
+    int32_t     MaxConcurrentActions = 0;   // 0 = auto (LogicalProcessorCount * 2)
+    bool        EnableWorkerWebSocket = false; // Use WebSocket for worker↔orchestrator link
+
+# if ZEN_WITH_HORDE
+    horde::HordeConfig HordeConfig;
+# endif
+
+# if ZEN_WITH_NOMAD
+    nomad::NomadConfig NomadConfig;
+# endif
+};
+
+/** Binds compute-server CLI and config-file options onto a
+ * ZenComputeServerConfig. Enum-like options are first captured as strings
+ * (the m_*Str members) and converted to typed values in ValidateOptions(). */
+struct ZenComputeServerConfigurator : public ZenServerConfiguratorBase
+{
+    ZenComputeServerConfigurator(ZenComputeServerConfig& ServerOptions)
+    : ZenServerConfiguratorBase(ServerOptions)
+    , m_ServerOptions(ServerOptions)
+    {
+    }
+
+    ~ZenComputeServerConfigurator() = default;
+
+private:
+    virtual void AddCliOptions(cxxopts::Options& Options) override;
+    virtual void AddConfigOptions(LuaConfig::Options& Options) override;
+    virtual void ApplyOptions(cxxopts::Options& Options) override;
+    virtual void OnConfigFileParsed(LuaConfig::Options& LuaOptions) override;
+    virtual void ValidateOptions() override;
+
+    ZenComputeServerConfig& m_ServerOptions;
+
+    // Raw string captures for enum-like CLI options; see ValidateOptions().
+# if ZEN_WITH_HORDE
+    std::string m_HordeModeStr = "direct";
+    std::string m_HordeEncryptionStr = "none";
+# endif
+
+# if ZEN_WITH_NOMAD
+    std::string m_NomadDriverStr = "raw_exec";
+    std::string m_NomadDistributionStr = "predeployed";
+# endif
+};
+
+/** Entry-point wrapper that owns the compute-server run loop (DoRun). The
+ * Config/Configurator aliases are consumed by the generic server bootstrap
+ * code to select the right option types for this server mode. */
+class ZenComputeServerMain : public ZenServerMain
+{
+public:
+    ZenComputeServerMain(ZenComputeServerConfig& ServerOptions);
+    virtual void DoRun(ZenServerState::ZenServerEntry* Entry) override;
+
+    ZenComputeServerMain(const ZenComputeServerMain&) = delete;
+    ZenComputeServerMain& operator=(const ZenComputeServerMain&) = delete;
+
+    // Modern alias form (was typedef); same names, same meaning.
+    using Config = ZenComputeServerConfig;
+    using Configurator = ZenComputeServerConfigurator;
+
+private:
+    ZenComputeServerConfig& m_ServerOptions;
+};
+
+/**
+ * The compute server handles DDC build function execution requests
+ * only. It's intended to be used on a pure compute resource and does
+ * not handle any storage tasks. The actual scheduling happens upstream
+ * in a storage server instance.
+ */
+
+class ZenComputeServer : public ZenServerBase
+{
+    ZenComputeServer& operator=(ZenComputeServer&&) = delete;
+    ZenComputeServer(ZenComputeServer&&) = delete;
+
+public:
+    ZenComputeServer();
+    ~ZenComputeServer();
+
+    // Returns the effective base port on success, negative on failure.
+    int Initialize(const ZenComputeServerConfig& ServerConfig, ZenServerState::ZenServerEntry* ServerEntry);
+    // Blocks until shutdown is requested.
+    void Run();
+    // Tears everything down; also invoked from the destructor.
+    void Cleanup();
+
+private:
+    GcManager m_GcManager;
+    GcScheduler m_GcScheduler{m_GcManager};
+    std::unique_ptr<CidStore> m_CidStore;                                       // local CAS for action payloads
+    std::unique_ptr<HttpApiService> m_ApiService;
+    std::unique_ptr<zen::compute::HttpComputeService> m_ComputeService;         // executes build functions
+    std::unique_ptr<zen::compute::HttpOrchestratorService> m_OrchestratorService;
+    std::unique_ptr<zen::compute::CloudMetadata> m_CloudMetadata;               // resolved result of the async probe
+    std::future<std::unique_ptr<zen::compute::CloudMetadata>> m_CloudMetadataFuture;
+    std::unique_ptr<HttpFrontendService> m_FrontendService;
+# if ZEN_WITH_HORDE
+    std::unique_ptr<zen::horde::HordeProvisioner> m_HordeProvisioner;
+# endif
+# if ZEN_WITH_NOMAD
+    std::unique_ptr<zen::nomad::NomadProvisioner> m_NomadProvisioner;
+# endif
+    SystemMetricsTracker m_MetricsTracker;
+    std::string m_CoordinatorEndpoint;  // empty disables announcing
+    std::string m_InstanceId;           // may be empty; see GetInstanceId()
+
+    asio::steady_timer m_AnnounceTimer{m_IoContext};
+    asio::steady_timer m_ProvisionerMaintenanceTimer{m_IoContext};
+
+    void InitializeState(const ZenComputeServerConfig& ServerConfig);
+    void InitializeServices(const ZenComputeServerConfig& ServerConfig);
+    void RegisterServices(const ZenComputeServerConfig& ServerConfig);
+    void ResolveCloudMetadata();
+    void PostAnnounce();
+    void EnqueueAnnounceTimer();
+    void EnqueueProvisionerMaintenanceTimer();
+    void ProvisionerMaintenanceTick();
+    std::string GetAnnounceUrl() const;
+    std::string GetInstanceId() const;
+    CbObject BuildAnnounceBody();
+
+    // Worker→orchestrator WebSocket client
+    struct OrchestratorWsHandler : public IWsClientHandler
+    {
+        ZenComputeServer& Server;
+        explicit OrchestratorWsHandler(ZenComputeServer& S) : Server(S) {}
+
+        void OnWsOpen() override;
+        void OnWsMessage(const WebSocketMessage& Msg) override;
+        void OnWsClose(uint16_t Code, std::string_view Reason) override;
+    };
+
+    // Handler must outlive the client (the client references it).
+    std::unique_ptr<OrchestratorWsHandler> m_OrchestratorWsHandler;
+    std::unique_ptr<HttpWsClient> m_OrchestratorWsClient;
+    asio::steady_timer m_WsReconnectTimer{m_IoContext};
+    bool m_EnableWorkerWebSocket = false;
+
+    void InitializeOrchestratorWebSocket();
+    void EnqueueWsReconnect();
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/config/config.cpp b/src/zenserver/config/config.cpp
index 07913e891..e36352dae 100644
--- a/src/zenserver/config/config.cpp
+++ b/src/zenserver/config/config.cpp
@@ -16,8 +16,8 @@
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
#include <zencore/string.h>
-#include <zenutil/commandlineoptions.h>
-#include <zenutil/environmentoptions.h>
+#include <zenutil/config/commandlineoptions.h>
+#include <zenutil/config/environmentoptions.h>
ZEN_THIRD_PARTY_INCLUDES_START
#include <fmt/format.h>
@@ -119,10 +119,17 @@ ZenServerConfiguratorBase::AddCommonConfigOptions(LuaConfig::Options& LuaOptions
ZenServerConfig& ServerOptions = m_ServerOptions;
+ // logging
+
+ LuaOptions.AddOption("server.logid"sv, ServerOptions.LoggingConfig.LogId, "log-id"sv);
+ LuaOptions.AddOption("server.abslog"sv, ServerOptions.LoggingConfig.AbsLogFile, "abslog"sv);
+ LuaOptions.AddOption("server.otlpendpoint"sv, ServerOptions.LoggingConfig.OtelEndpointUri, "otlp-endpoint"sv);
+ LuaOptions.AddOption("server.quiet"sv, ServerOptions.LoggingConfig.QuietConsole, "quiet"sv);
+ LuaOptions.AddOption("server.noconsole"sv, ServerOptions.LoggingConfig.NoConsoleOutput, "noconsole"sv);
+
// server
LuaOptions.AddOption("server.dedicated"sv, ServerOptions.IsDedicated, "dedicated"sv);
- LuaOptions.AddOption("server.logid"sv, ServerOptions.LogId, "log-id"sv);
LuaOptions.AddOption("server.sentry.disable"sv, ServerOptions.SentryConfig.Disable, "no-sentry"sv);
LuaOptions.AddOption("server.sentry.allowpersonalinfo"sv, ServerOptions.SentryConfig.AllowPII, "sentry-allow-personal-info"sv);
LuaOptions.AddOption("server.sentry.dsn"sv, ServerOptions.SentryConfig.Dsn, "sentry-dsn"sv);
@@ -131,12 +138,9 @@ ZenServerConfiguratorBase::AddCommonConfigOptions(LuaConfig::Options& LuaOptions
LuaOptions.AddOption("server.systemrootdir"sv, ServerOptions.SystemRootDir, "system-dir"sv);
LuaOptions.AddOption("server.datadir"sv, ServerOptions.DataDir, "data-dir"sv);
LuaOptions.AddOption("server.contentdir"sv, ServerOptions.ContentDir, "content-dir"sv);
- LuaOptions.AddOption("server.abslog"sv, ServerOptions.AbsLogFile, "abslog"sv);
- LuaOptions.AddOption("server.otlpendpoint"sv, ServerOptions.OtelEndpointUri, "otlp-endpoint"sv);
LuaOptions.AddOption("server.debug"sv, ServerOptions.IsDebug, "debug"sv);
LuaOptions.AddOption("server.clean"sv, ServerOptions.IsCleanStart, "clean"sv);
- LuaOptions.AddOption("server.quiet"sv, ServerOptions.QuietConsole, "quiet"sv);
- LuaOptions.AddOption("server.noconsole"sv, ServerOptions.NoConsoleOutput, "noconsole"sv);
+ LuaOptions.AddOption("server.security.configpath"sv, ServerOptions.SecurityConfigPath, "security-config-path"sv);
////// network
@@ -182,8 +186,10 @@ struct ZenServerCmdLineOptions
std::string SystemRootDir;
std::string ContentDir;
std::string DataDir;
- std::string AbsLogFile;
std::string BaseSnapshotDir;
+ std::string SecurityConfigPath;
+
+ ZenLoggingCmdLineOptions LoggingOptions;
void AddCliOptions(cxxopts::Options& options, ZenServerConfig& ServerOptions);
void ApplyOptions(cxxopts::Options& options, ZenServerConfig& ServerOptions);
@@ -249,22 +255,7 @@ ZenServerCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenServerConfi
cxxopts::value<bool>(ServerOptions.ShouldCrash)->default_value("false"),
"");
- // clang-format off
- options.add_options("logging")
- ("abslog", "Path to log file", cxxopts::value<std::string>(AbsLogFile))
- ("log-id", "Specify id for adding context to log output", cxxopts::value<std::string>(ServerOptions.LogId))
- ("quiet", "Configure console logger output to level WARN", cxxopts::value<bool>(ServerOptions.QuietConsole)->default_value("false"))
- ("noconsole", "Disable console logging", cxxopts::value<bool>(ServerOptions.NoConsoleOutput)->default_value("false"))
- ("log-trace", "Change selected loggers to level TRACE", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Trace]))
- ("log-debug", "Change selected loggers to level DEBUG", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Debug]))
- ("log-info", "Change selected loggers to level INFO", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Info]))
- ("log-warn", "Change selected loggers to level WARN", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Warn]))
- ("log-error", "Change selected loggers to level ERROR", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Err]))
- ("log-critical", "Change selected loggers to level CRITICAL", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Critical]))
- ("log-off", "Change selected loggers to level OFF", cxxopts::value<std::string>(ServerOptions.Loggers[logging::level::Off]))
- ("otlp-endpoint", "OpenTelemetry endpoint URI (e.g http://localhost:4318)", cxxopts::value<std::string>(ServerOptions.OtelEndpointUri))
- ;
- // clang-format on
+ LoggingOptions.AddCliOptions(options, ServerOptions.LoggingConfig);
options
.add_option("lifetime", "", "owner-pid", "Specify owning process id", cxxopts::value<int>(ServerOptions.OwnerPid), "<identifier>");
@@ -311,6 +302,13 @@ ZenServerCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenServerConfi
cxxopts::value<bool>(ServerOptions.HttpConfig.ForceLoopback)->default_value("false"),
"<http forceloopback>");
+ options.add_option("network",
+ "",
+ "security-config-path",
+ "Path to http security configuration file",
+ cxxopts::value<std::string>(SecurityConfigPath),
+ "<security config path>");
+
#if ZEN_WITH_HTTPSYS
options.add_option("httpsys",
"",
@@ -391,12 +389,14 @@ ZenServerCmdLineOptions::ApplyOptions(cxxopts::Options& options, ZenServerConfig
throw std::runtime_error(fmt::format("'--snapshot-dir' ('{}') must be a directory", ServerOptions.BaseSnapshotDir));
}
- ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir);
- ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir);
- ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir);
- ServerOptions.AbsLogFile = MakeSafeAbsolutePath(AbsLogFile);
- ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile);
- ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir);
+ ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir);
+ ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir);
+ ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir);
+ ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile);
+ ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir);
+ ServerOptions.SecurityConfigPath = MakeSafeAbsolutePath(SecurityConfigPath);
+
+ LoggingOptions.ApplyOptions(ServerOptions.LoggingConfig);
}
//////////////////////////////////////////////////////////////////////////
@@ -466,34 +466,7 @@ ZenServerConfiguratorBase::Configure(int argc, char* argv[])
}
#endif
- if (m_ServerOptions.QuietConsole)
- {
- bool HasExplicitConsoleLevel = false;
- for (int i = 0; i < logging::level::LogLevelCount; ++i)
- {
- if (m_ServerOptions.Loggers[i].find("console") != std::string::npos)
- {
- HasExplicitConsoleLevel = true;
- break;
- }
- }
-
- if (!HasExplicitConsoleLevel)
- {
- std::string& WarnLoggers = m_ServerOptions.Loggers[logging::level::Warn];
- if (!WarnLoggers.empty())
- {
- WarnLoggers += ",";
- }
- WarnLoggers += "console";
- }
- }
-
- for (int i = 0; i < logging::level::LogLevelCount; ++i)
- {
- logging::ConfigureLogLevels(logging::level::LogLevel(i), m_ServerOptions.Loggers[i]);
- }
- logging::RefreshLogLevels();
+ ApplyLoggingOptions(options, m_ServerOptions.LoggingConfig);
BaseOptions.ApplyOptions(options, m_ServerOptions);
ApplyOptions(options);
@@ -532,9 +505,9 @@ ZenServerConfiguratorBase::Configure(int argc, char* argv[])
m_ServerOptions.DataDir = PickDefaultStateDirectory(m_ServerOptions.SystemRootDir);
}
- if (m_ServerOptions.AbsLogFile.empty())
+ if (m_ServerOptions.LoggingConfig.AbsLogFile.empty())
{
- m_ServerOptions.AbsLogFile = m_ServerOptions.DataDir / "logs" / "zenserver.log";
+ m_ServerOptions.LoggingConfig.AbsLogFile = m_ServerOptions.DataDir / "logs" / "zenserver.log";
}
m_ServerOptions.HttpConfig.IsDedicatedServer = m_ServerOptions.IsDedicated;
diff --git a/src/zenserver/config/config.h b/src/zenserver/config/config.h
index 7c3192a1f..55aee07f9 100644
--- a/src/zenserver/config/config.h
+++ b/src/zenserver/config/config.h
@@ -6,6 +6,7 @@
#include <zencore/trace.h>
#include <zencore/zencore.h>
#include <zenhttp/httpserver.h>
+#include <zenutil/config/loggingconfig.h>
#include <filesystem>
#include <string>
#include <vector>
@@ -42,29 +43,26 @@ struct ZenServerConfig
HttpServerConfig HttpConfig;
ZenSentryConfig SentryConfig;
ZenStatsConfig StatsConfig;
- int BasePort = 8558; // Service listen port (used for both UDP and TCP)
- int OwnerPid = 0; // Parent process id (zero for standalone)
- bool IsDebug = false;
- bool IsCleanStart = false; // Indicates whether all state should be wiped on startup or not
- bool IsPowerCycle = false; // When true, the process shuts down immediately after initialization
- bool IsTest = false;
- bool Detach = true; // Whether zenserver should detach from existing process group (Mac/Linux)
- bool NoConsoleOutput = false; // Control default use of stdout for diagnostics
- bool QuietConsole = false; // Configure console logger output to level WARN
- int CoreLimit = 0; // If set, hardware concurrency queries are capped at this number
- bool IsDedicated = false; // Indicates a dedicated/shared instance, with larger resource requirements
- bool ShouldCrash = false; // Option for testing crash handling
- bool IsFirstRun = false;
- std::filesystem::path ConfigFile; // Path to Lua config file
- std::filesystem::path SystemRootDir; // System root directory (used for machine level config)
- std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental)
- std::filesystem::path DataDir; // Root directory for state (used for testing)
- std::filesystem::path AbsLogFile; // Absolute path to main log file
- std::filesystem::path BaseSnapshotDir; // Path to server state snapshot (will be copied into data dir on start)
- std::string ChildId; // Id assigned by parent process (used for lifetime management)
- std::string LogId; // Id for tagging log output
- std::string Loggers[zen::logging::level::LogLevelCount];
- std::string OtelEndpointUri; // OpenTelemetry endpoint URI
+ ZenLoggingConfig LoggingConfig;
+ int BasePort = 8558; // Service listen port (used for both UDP and TCP)
+ int OwnerPid = 0; // Parent process id (zero for standalone)
+ bool IsDebug = false;
+ bool IsCleanStart = false; // Indicates whether all state should be wiped on startup or not
+ bool IsPowerCycle = false; // When true, the process shuts down immediately after initialization
+ bool IsTest = false;
+ bool Detach = true; // Whether zenserver should detach from existing process group (Mac/Linux)
+ int CoreLimit = 0; // If set, hardware concurrency queries are capped at this number
+ int LieCpu = 0;
+ bool IsDedicated = false; // Indicates a dedicated/shared instance, with larger resource requirements
+ bool ShouldCrash = false; // Option for testing crash handling
+ bool IsFirstRun = false;
+ std::filesystem::path ConfigFile; // Path to Lua config file
+ std::filesystem::path SystemRootDir; // System root directory (used for machine level config)
+ std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental)
+ std::filesystem::path DataDir; // Root directory for state (used for testing)
+ std::filesystem::path BaseSnapshotDir; // Path to server state snapshot (will be copied into data dir on start)
+ std::string ChildId; // Id assigned by parent process (used for lifetime management)
+    std::filesystem::path SecurityConfigPath;          // Path to a JSON security configuration file
#if ZEN_WITH_TRACE
bool HasTraceCommandlineOptions = false;
diff --git a/src/zenserver/config/luaconfig.h b/src/zenserver/config/luaconfig.h
index ce7013a9a..e3ac3b343 100644
--- a/src/zenserver/config/luaconfig.h
+++ b/src/zenserver/config/luaconfig.h
@@ -4,7 +4,7 @@
#include <zenbase/concepts.h>
#include <zencore/fmtutils.h>
-#include <zenutil/commandlineoptions.h>
+#include <zenutil/config/commandlineoptions.h>
ZEN_THIRD_PARTY_INCLUDES_START
#include <fmt/format.h>
diff --git a/src/zenserver/diag/diagsvcs.cpp b/src/zenserver/diag/diagsvcs.cpp
index d8d53b0e3..dd4b8956c 100644
--- a/src/zenserver/diag/diagsvcs.cpp
+++ b/src/zenserver/diag/diagsvcs.cpp
@@ -9,12 +9,11 @@
#include <zencore/logging.h>
#include <zencore/memory/llm.h>
#include <zencore/string.h>
+#include <zencore/system.h>
#include <fstream>
#include <sstream>
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <spdlog/logger.h>
-ZEN_THIRD_PARTY_INCLUDES_END
+#include <zencore/logging/logger.h>
namespace zen {
@@ -53,6 +52,36 @@ HttpHealthService::HttpHealthService()
Writer << "AbsLogPath"sv << m_HealthInfo.AbsLogPath.string();
Writer << "BuildVersion"sv << m_HealthInfo.BuildVersion;
Writer << "HttpServerClass"sv << m_HealthInfo.HttpServerClass;
+ Writer << "Port"sv << m_HealthInfo.Port;
+ Writer << "Pid"sv << m_HealthInfo.Pid;
+ Writer << "IsDedicated"sv << m_HealthInfo.IsDedicated;
+ Writer << "StartTimeMs"sv << m_HealthInfo.StartTimeMs;
+ }
+
+ Writer.BeginObject("RuntimeConfig"sv);
+ for (const auto& Opt : m_HealthInfo.RuntimeConfig)
+ {
+ Writer << Opt.first << Opt.second;
+ }
+ Writer.EndObject();
+
+ Writer.BeginObject("BuildConfig"sv);
+ for (const auto& Opt : m_HealthInfo.BuildOptions)
+ {
+ Writer << Opt.first << Opt.second;
+ }
+ Writer.EndObject();
+
+ Writer << "Hostname"sv << GetMachineName();
+ Writer << "Platform"sv << GetRuntimePlatformName();
+ Writer << "Arch"sv << GetCpuName();
+ Writer << "OS"sv << GetOperatingSystemVersion();
+
+ {
+ auto Metrics = GetSystemMetrics();
+ Writer.BeginObject("System"sv);
+ Describe(Metrics, Writer);
+ Writer.EndObject();
}
HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save());
@@ -64,7 +93,7 @@ HttpHealthService::HttpHealthService()
[this](HttpRouterRequest& RoutedReq) {
HttpServerRequest& HttpReq = RoutedReq.ServerRequest();
- zen::Log().SpdLogger->flush();
+ zen::Log().Flush();
std::filesystem::path Path = [&] {
RwLock::SharedLockScope _(m_InfoLock);
diff --git a/src/zenserver/diag/diagsvcs.h b/src/zenserver/diag/diagsvcs.h
index 8cc869c83..87ce80b3c 100644
--- a/src/zenserver/diag/diagsvcs.h
+++ b/src/zenserver/diag/diagsvcs.h
@@ -6,6 +6,7 @@
#include <zenhttp/httpserver.h>
#include <filesystem>
+#include <vector>
//////////////////////////////////////////////////////////////////////////
@@ -89,10 +90,16 @@ private:
struct HealthServiceInfo
{
- std::filesystem::path DataRoot;
- std::filesystem::path AbsLogPath;
- std::string HttpServerClass;
- std::string BuildVersion;
+ std::filesystem::path DataRoot;
+ std::filesystem::path AbsLogPath;
+ std::string HttpServerClass;
+ std::string BuildVersion;
+ int Port = 0;
+ int Pid = 0;
+ bool IsDedicated = false;
+ int64_t StartTimeMs = 0;
+ std::vector<std::pair<std::string_view, bool>> BuildOptions;
+ std::vector<std::pair<std::string_view, std::string>> RuntimeConfig;
};
/** Health monitoring endpoint
diff --git a/src/zenserver/diag/logging.cpp b/src/zenserver/diag/logging.cpp
index 4962b9006..178c3d3b5 100644
--- a/src/zenserver/diag/logging.cpp
+++ b/src/zenserver/diag/logging.cpp
@@ -6,6 +6,8 @@
#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
+#include <zencore/logging/logger.h>
+#include <zencore/logging/registry.h>
#include <zencore/memory/llm.h>
#include <zencore/session.h>
#include <zencore/string.h>
@@ -14,10 +16,6 @@
#include "otlphttp.h"
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <spdlog/spdlog.h>
-ZEN_THIRD_PARTY_INCLUDES_END
-
namespace zen {
void
@@ -28,10 +26,10 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService)
const LoggingOptions LogOptions = {.IsDebug = InOptions.IsDebug,
.IsVerbose = false,
.IsTest = InOptions.IsTest,
- .NoConsoleOutput = InOptions.NoConsoleOutput,
- .QuietConsole = InOptions.QuietConsole,
- .AbsLogFile = InOptions.AbsLogFile,
- .LogId = InOptions.LogId};
+ .NoConsoleOutput = InOptions.LoggingConfig.NoConsoleOutput,
+ .QuietConsole = InOptions.LoggingConfig.QuietConsole,
+ .AbsLogFile = InOptions.LoggingConfig.AbsLogFile,
+ .LogId = InOptions.LoggingConfig.LogId};
BeginInitializeLogging(LogOptions);
@@ -43,13 +41,12 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService)
std::filesystem::path HttpLogPath = InOptions.DataDir / "logs" / "http.log";
zen::CreateDirectories(HttpLogPath.parent_path());
- auto HttpSink = std::make_shared<zen::logging::RotatingFileSink>(HttpLogPath,
- /* max size */ 128 * 1024 * 1024,
- /* max files */ 16,
- /* rotate on open */ true);
- auto HttpLogger = std::make_shared<spdlog::logger>("http_requests", HttpSink);
- spdlog::apply_logger_env_levels(HttpLogger);
- spdlog::register_logger(HttpLogger);
+ logging::SinkPtr HttpSink(new zen::logging::RotatingFileSink(HttpLogPath,
+ /* max size */ 128 * 1024 * 1024,
+ /* max files */ 16,
+ /* rotate on open */ true));
+ Ref<logging::Logger> HttpLogger(new logging::Logger("http_requests", std::vector<logging::SinkPtr>{HttpSink}));
+ logging::Registry::Instance().Register(HttpLogger);
if (WithCacheService)
{
@@ -57,33 +54,30 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService)
std::filesystem::path CacheLogPath = InOptions.DataDir / "logs" / "z$.log";
zen::CreateDirectories(CacheLogPath.parent_path());
- auto CacheSink = std::make_shared<zen::logging::RotatingFileSink>(CacheLogPath,
- /* max size */ 128 * 1024 * 1024,
- /* max files */ 16,
- /* rotate on open */ false);
- auto CacheLogger = std::make_shared<spdlog::logger>("z$", CacheSink);
- spdlog::apply_logger_env_levels(CacheLogger);
- spdlog::register_logger(CacheLogger);
+ logging::SinkPtr CacheSink(new zen::logging::RotatingFileSink(CacheLogPath,
+ /* max size */ 128 * 1024 * 1024,
+ /* max files */ 16,
+ /* rotate on open */ false));
+ Ref<logging::Logger> CacheLogger(new logging::Logger("z$", std::vector<logging::SinkPtr>{CacheSink}));
+ logging::Registry::Instance().Register(CacheLogger);
// Jupiter - only log upstream HTTP traffic to file
- auto JupiterLogger = std::make_shared<spdlog::logger>("jupiter", FileSink);
- spdlog::apply_logger_env_levels(JupiterLogger);
- spdlog::register_logger(JupiterLogger);
+ Ref<logging::Logger> JupiterLogger(new logging::Logger("jupiter", std::vector<logging::SinkPtr>{FileSink}));
+ logging::Registry::Instance().Register(JupiterLogger);
// Zen - only log upstream HTTP traffic to file
- auto ZenClientLogger = std::make_shared<spdlog::logger>("zenclient", FileSink);
- spdlog::apply_logger_env_levels(ZenClientLogger);
- spdlog::register_logger(ZenClientLogger);
+ Ref<logging::Logger> ZenClientLogger(new logging::Logger("zenclient", std::vector<logging::SinkPtr>{FileSink}));
+ logging::Registry::Instance().Register(ZenClientLogger);
}
#if ZEN_WITH_OTEL
- if (!InOptions.OtelEndpointUri.empty())
+ if (!InOptions.LoggingConfig.OtelEndpointUri.empty())
{
// TODO: Should sanity check that endpoint is reachable? Also, a valid URI?
- auto OtelSink = std::make_shared<zen::logging::OtelHttpProtobufSink>(InOptions.OtelEndpointUri);
- zen::logging::Default().SpdLogger->sinks().push_back(std::move(OtelSink));
+ logging::SinkPtr OtelSink(new zen::logging::OtelHttpProtobufSink(InOptions.LoggingConfig.OtelEndpointUri));
+ zen::logging::Default()->AddSink(std::move(OtelSink));
}
#endif
@@ -91,9 +85,10 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService)
const zen::Oid ServerSessionId = zen::GetSessionId();
- spdlog::apply_all([&](auto Logger) {
+ static constinit logging::LogPoint SessionIdPoint{{}, logging::Info, "server session id: {}"};
+ logging::Registry::Instance().ApplyAll([&](auto Logger) {
ZEN_MEMSCOPE(ELLMTag::Logging);
- Logger->info("server session id: {}", ServerSessionId);
+ Logger->Log(SessionIdPoint, fmt::make_format_args(ServerSessionId));
});
}
diff --git a/src/zenserver/diag/otlphttp.cpp b/src/zenserver/diag/otlphttp.cpp
index d62ccccb6..d6e24cbe3 100644
--- a/src/zenserver/diag/otlphttp.cpp
+++ b/src/zenserver/diag/otlphttp.cpp
@@ -10,11 +10,18 @@
#include <protozero/buffer_string.hpp>
#include <protozero/pbf_builder.hpp>
+#include <cstdio>
+
#if ZEN_WITH_OTEL
namespace zen::logging {
//////////////////////////////////////////////////////////////////////////
+//
+// Important note: in general we cannot use ZEN_WARN/ZEN_ERROR etc in this
+// file as it could cause recursive logging calls when we attempt to log
+// errors from the OTLP HTTP client itself.
+//
OtelHttpProtobufSink::OtelHttpProtobufSink(const std::string_view& Uri) : m_OtelHttp(Uri)
{
@@ -36,14 +43,44 @@ OtelHttpProtobufSink::~OtelHttpProtobufSink()
}
void
+OtelHttpProtobufSink::CheckPostResult(const HttpClient::Response& Result, const char* Endpoint) noexcept
+{
+ if (!Result.IsSuccess())
+ {
+ uint32_t PrevFailures = m_ConsecutivePostFailures.fetch_add(1);
+ if (PrevFailures < kMaxReportedFailures)
+ {
+ fprintf(stderr, "OtelHttpProtobufSink: %s\n", Result.ErrorMessage(Endpoint).c_str());
+ if (PrevFailures + 1 == kMaxReportedFailures)
+ {
+ fprintf(stderr, "OtelHttpProtobufSink: suppressing further export errors\n");
+ }
+ }
+ }
+ else
+ {
+ m_ConsecutivePostFailures.store(0);
+ }
+}
+
+void
OtelHttpProtobufSink::RecordSpans(zen::otel::TraceId Trace, std::span<const zen::otel::Span*> Spans)
{
- std::string Data = m_Encoder.FormatOtelTrace(Trace, Spans);
+ try
+ {
+ std::string Data = m_Encoder.FormatOtelTrace(Trace, Spans);
+
+ IoBuffer Payload{IoBuffer::Wrap, Data.data(), Data.size()};
+ Payload.SetContentType(ZenContentType::kProtobuf);
- IoBuffer Payload{IoBuffer::Wrap, Data.data(), Data.size()};
- Payload.SetContentType(ZenContentType::kProtobuf);
+ HttpClient::Response Result = m_OtelHttp.Post("/v1/traces", Payload);
- auto Result = m_OtelHttp.Post("/v1/traces", Payload);
+ CheckPostResult(Result, "POST /v1/traces");
+ }
+ catch (const std::exception& Ex)
+ {
+ fprintf(stderr, "OtelHttpProtobufSink: exception exporting traces: %s\n", Ex.what());
+ }
}
void
@@ -53,28 +90,26 @@ OtelHttpProtobufSink::TraceRecorder::RecordSpans(zen::otel::TraceId Trace, std::
}
void
-OtelHttpProtobufSink::log(const spdlog::details::log_msg& Msg)
+OtelHttpProtobufSink::Log(const LogMessage& Msg)
{
+ try
{
std::string Data = m_Encoder.FormatOtelProtobuf(Msg);
IoBuffer Payload{IoBuffer::Wrap, Data.data(), Data.size()};
Payload.SetContentType(ZenContentType::kProtobuf);
- auto Result = m_OtelHttp.Post("/v1/logs", Payload);
- }
+ HttpClient::Response Result = m_OtelHttp.Post("/v1/logs", Payload);
+ CheckPostResult(Result, "POST /v1/logs");
+ }
+ catch (const std::exception& Ex)
{
- std::string Data = m_Encoder.FormatOtelMetrics();
-
- IoBuffer Payload{IoBuffer::Wrap, Data.data(), Data.size()};
- Payload.SetContentType(ZenContentType::kProtobuf);
-
- auto Result = m_OtelHttp.Post("/v1/metrics", Payload);
+ fprintf(stderr, "OtelHttpProtobufSink: exception exporting logs: %s\n", Ex.what());
}
}
void
-OtelHttpProtobufSink::flush()
+OtelHttpProtobufSink::Flush()
{
}
diff --git a/src/zenserver/diag/otlphttp.h b/src/zenserver/diag/otlphttp.h
index 2281bdcc0..64b3dbc87 100644
--- a/src/zenserver/diag/otlphttp.h
+++ b/src/zenserver/diag/otlphttp.h
@@ -3,23 +3,25 @@
#pragma once
-#include <spdlog/sinks/sink.h>
+#include <zencore/logging/sink.h>
#include <zencore/zencore.h>
#include <zenhttp/httpclient.h>
#include <zentelemetry/otlpencoder.h>
#include <zentelemetry/otlptrace.h>
+#include <atomic>
+
#if ZEN_WITH_OTEL
namespace zen::logging {
/**
- * OTLP/HTTP sink for spdlog
+ * OTLP/HTTP sink for logging
*
* Sends log messages and traces to an OpenTelemetry collector via OTLP over HTTP
*/
-class OtelHttpProtobufSink : public spdlog::sinks::sink
+class OtelHttpProtobufSink : public Sink
{
public:
// Note that this URI should be the base URI of the OTLP HTTP endpoint, e.g.
@@ -31,12 +33,12 @@ public:
OtelHttpProtobufSink& operator=(const OtelHttpProtobufSink&) = delete;
private:
- virtual void log(const spdlog::details::log_msg& Msg) override;
- virtual void flush() override;
- virtual void set_pattern(const std::string& pattern) override { ZEN_UNUSED(pattern); }
- virtual void set_formatter(std::unique_ptr<spdlog::formatter> sink_formatter) override { ZEN_UNUSED(sink_formatter); }
+ virtual void Log(const LogMessage& Msg) override;
+ virtual void Flush() override;
+ virtual void SetFormatter(std::unique_ptr<Formatter>) override {}
void RecordSpans(zen::otel::TraceId Trace, std::span<const zen::otel::Span*> Spans);
+ void CheckPostResult(const HttpClient::Response& Result, const char* Endpoint) noexcept;
// This is just a thin wrapper to call back into the sink while participating in
// reference counting from the OTEL trace back-end
@@ -54,11 +56,15 @@ private:
OtelHttpProtobufSink* m_Sink;
};
- HttpClient m_OtelHttp;
- OtlpEncoder m_Encoder;
- Ref<TraceRecorder> m_TraceRecorder;
+ static constexpr uint32_t kMaxReportedFailures = 5;
+
+ RwLock m_Lock;
+ std::atomic<uint32_t> m_ConsecutivePostFailures{0};
+ HttpClient m_OtelHttp;
+ OtlpEncoder m_Encoder;
+ Ref<TraceRecorder> m_TraceRecorder;
};
} // namespace zen::logging
-#endif \ No newline at end of file
+#endif
diff --git a/src/zenserver/frontend/frontend.cpp b/src/zenserver/frontend/frontend.cpp
index 2b157581f..579a65c5a 100644
--- a/src/zenserver/frontend/frontend.cpp
+++ b/src/zenserver/frontend/frontend.cpp
@@ -38,7 +38,7 @@ HttpFrontendService::HttpFrontendService(std::filesystem::path Directory, HttpSt
#if ZEN_EMBED_HTML_ZIP
// Load an embedded Zip archive
IoBuffer HtmlZipDataBuffer(IoBuffer::Wrap, gHtmlZipData, sizeof(gHtmlZipData) - 1);
- m_ZipFs = ZipFs(std::move(HtmlZipDataBuffer));
+ m_ZipFs = std::make_unique<ZipFs>(std::move(HtmlZipDataBuffer));
#endif
if (m_Directory.empty() && !m_ZipFs)
@@ -114,6 +114,8 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request)
{
using namespace std::literals;
+ ExtendableStringBuilder<256> UriBuilder;
+
std::string_view Uri = Request.RelativeUriWithExtension();
for (; Uri.length() > 0 && Uri[0] == '/'; Uri = Uri.substr(1))
;
@@ -121,6 +123,11 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request)
{
Uri = "index.html"sv;
}
+ else if (Uri.back() == '/')
+ {
+ UriBuilder << Uri << "index.html"sv;
+ Uri = UriBuilder;
+ }
// Dismiss if the URI contains .. anywhere to prevent arbitrary file reads
if (Uri.find("..") != Uri.npos)
@@ -145,24 +152,47 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request)
return Request.WriteResponse(HttpResponseCode::Forbidden);
}
- // The given content directory overrides any zip-fs discovered in the binary
- if (!m_Directory.empty())
- {
- auto FullPath = m_Directory / std::filesystem::path(Uri).make_preferred();
- FileContents File = ReadFile(FullPath);
+ auto WriteResponseForUri = [this,
+ &Request](std::string_view InUri, HttpResponseCode ResponseCode, HttpContentType ContentType) -> bool {
+ // The given content directory overrides any zip-fs discovered in the binary
+ if (!m_Directory.empty())
+ {
+ auto FullPath = m_Directory / std::filesystem::path(InUri).make_preferred();
+ FileContents File = ReadFile(FullPath);
- if (!File.ErrorCode)
+ if (!File.ErrorCode)
+ {
+ Request.WriteResponse(ResponseCode, ContentType, File.Data[0]);
+
+ return true;
+ }
+ }
+
+ if (m_ZipFs)
{
- return Request.WriteResponse(HttpResponseCode::OK, ContentType, File.Data[0]);
+ if (IoBuffer FileBuffer = m_ZipFs->GetFile(InUri))
+ {
+ Request.WriteResponse(HttpResponseCode::OK, ContentType, FileBuffer);
+
+ return true;
+ }
}
- }
- if (IoBuffer FileBuffer = m_ZipFs.GetFile(Uri))
+ return false;
+ };
+
+ if (WriteResponseForUri(Uri, HttpResponseCode::OK, ContentType))
{
- return Request.WriteResponse(HttpResponseCode::OK, ContentType, FileBuffer);
+ return;
+ }
+ else if (WriteResponseForUri("404.html"sv, HttpResponseCode::NotFound, HttpContentType::kHTML))
+ {
+ return;
+ }
+ else
+ {
+ Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv);
}
-
- Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv);
}
} // namespace zen
diff --git a/src/zenserver/frontend/frontend.h b/src/zenserver/frontend/frontend.h
index 84ffaac42..6d8585b72 100644
--- a/src/zenserver/frontend/frontend.h
+++ b/src/zenserver/frontend/frontend.h
@@ -7,6 +7,7 @@
#include "zipfs.h"
#include <filesystem>
+#include <memory>
namespace zen {
@@ -20,9 +21,9 @@ public:
virtual void HandleStatusRequest(HttpServerRequest& Request) override;
private:
- ZipFs m_ZipFs;
- std::filesystem::path m_Directory;
- HttpStatusService& m_StatusService;
+ std::unique_ptr<ZipFs> m_ZipFs;
+ std::filesystem::path m_Directory;
+ HttpStatusService& m_StatusService;
};
} // namespace zen
diff --git a/src/zenserver/frontend/html.zip b/src/zenserver/frontend/html.zip
index 5d33302dd..84472ff08 100644
--- a/src/zenserver/frontend/html.zip
+++ b/src/zenserver/frontend/html.zip
Binary files differ
diff --git a/src/zenserver/frontend/html/404.html b/src/zenserver/frontend/html/404.html
new file mode 100644
index 000000000..829ef2097
--- /dev/null
+++ b/src/zenserver/frontend/html/404.html
@@ -0,0 +1,486 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+<meta charset="UTF-8">
+<meta name="viewport" content="width=device-width, initial-scale=1.0">
+<title>Ooops</title>
+<style>
+ * { margin: 0; padding: 0; box-sizing: border-box; }
+
+ :root {
+ --deep-space: #00000f;
+ --nebula-blue: #0a0a2e;
+ --star-white: #ffffff;
+ --star-blue: #c8d8ff;
+ --star-yellow: #fff3c0;
+ --star-red: #ffd0c0;
+ --nebula-glow: rgba(60, 80, 180, 0.12);
+ }
+
+ body {
+ background: var(--deep-space);
+ min-height: 100vh;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ font-family: 'Courier New', monospace;
+ overflow: hidden;
+ }
+
+ starfield-bg {
+ display: block;
+ position: fixed;
+ inset: 0;
+ z-index: 0;
+ }
+
+ canvas {
+ display: block;
+ width: 100%;
+ height: 100%;
+ }
+
+ .page-content {
+ position: relative;
+ z-index: 1;
+ text-align: center;
+ color: rgba(200, 216, 255, 0.85);
+ letter-spacing: 0.25em;
+ text-transform: uppercase;
+ pointer-events: none;
+ user-select: none;
+ }
+
+ .page-content h1 {
+ font-size: clamp(1.2rem, 4vw, 2.4rem);
+ font-weight: 300;
+ letter-spacing: 0.6em;
+ text-shadow: 0 0 40px rgba(120, 160, 255, 0.6), 0 0 80px rgba(80, 120, 255, 0.3);
+ animation: pulse 6s ease-in-out infinite;
+ }
+
+ .page-content p {
+ margin-top: 1.2rem;
+ font-size: clamp(0.55rem, 1.5vw, 0.75rem);
+ letter-spacing: 0.4em;
+ opacity: 0.45;
+ }
+
+ @keyframes pulse {
+ 0%, 100% { opacity: 0.7; }
+ 50% { opacity: 1; }
+ }
+
+ .globe-link {
+ display: block;
+ margin: 0 auto 2rem;
+ width: 160px;
+ height: 160px;
+ pointer-events: auto;
+ cursor: pointer;
+ border-radius: 50%;
+ position: relative;
+ }
+
+ .globe-link:hover .globe-glow {
+ opacity: 0.6;
+ }
+
+ .globe-glow {
+ position: absolute;
+ inset: -18px;
+ border-radius: 50%;
+ background: radial-gradient(circle, rgba(80, 140, 255, 0.35) 0%, transparent 70%);
+ opacity: 0.35;
+ transition: opacity 0.4s;
+ pointer-events: none;
+ }
+
+ .globe-link canvas {
+ display: block;
+ width: 160px;
+ height: 160px;
+ border-radius: 50%;
+ }
+</style>
+</head>
+<body>
+
+<starfield-bg
+ star-count="380"
+ speed="0.6"
+ depth="true"
+ nebula="true"
+ shooting-stars="true"
+></starfield-bg>
+
+<div class="page-content">
+ <a class="globe-link" href="/dashboard/" title="Back to Dashboard">
+ <div class="globe-glow"></div>
+ <canvas id="globe" width="320" height="320"></canvas>
+ </a>
+ <h1>404 NOT FOUND</h1>
+</div>
+
+<script>
+class StarfieldBg extends HTMLElement {
+ constructor() {
+ super();
+ this.attachShadow({ mode: 'open' });
+ }
+
+ connectedCallback() {
+ this.shadowRoot.innerHTML = `
+ <style>
+ :host { display: block; position: absolute; inset: 0; overflow: hidden; }
+ canvas { width: 100%; height: 100%; display: block; }
+ </style>
+ <canvas></canvas>
+ `;
+
+ this.canvas = this.shadowRoot.querySelector('canvas');
+ this.ctx = this.canvas.getContext('2d');
+
+ this.starCount = parseInt(this.getAttribute('star-count') || '350');
+ this.speed = parseFloat(this.getAttribute('speed') || '0.6');
+ this.useDepth = this.getAttribute('depth') !== 'false';
+ this.useNebula = this.getAttribute('nebula') !== 'false';
+ this.useShooting = this.getAttribute('shooting-stars') !== 'false';
+
+ this.stars = [];
+ this.shooters = [];
+ this.nebulaTime = 0;
+ this.frame = 0;
+
+ this.resize();
+ this.init();
+
+ this._ro = new ResizeObserver(() => { this.resize(); this.init(); });
+ this._ro.observe(this);
+
+ this.raf = requestAnimationFrame(this.tick.bind(this));
+ }
+
+ disconnectedCallback() {
+ cancelAnimationFrame(this.raf);
+ this._ro.disconnect();
+ }
+
+ resize() {
+ const dpr = window.devicePixelRatio || 1;
+ const rect = this.getBoundingClientRect();
+ this.W = rect.width || window.innerWidth;
+ this.H = rect.height || window.innerHeight;
+ this.canvas.width = this.W * dpr;
+ this.canvas.height = this.H * dpr;
+ this.ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
+ }
+
+ init() {
+ const COLORS = ['#ffffff', '#c8d8ff', '#d0e8ff', '#fff3c0', '#ffd0c0', '#e0f0ff'];
+ this.stars = Array.from({ length: this.starCount }, () => ({
+ x: Math.random() * this.W,
+ y: Math.random() * this.H,
+ z: this.useDepth ? Math.random() : 1, // depth: 0=far, 1=near
+ r: Math.random() * 1.4 + 0.2,
+ color: COLORS[Math.floor(Math.random() * COLORS.length)],
+ twinkleOffset: Math.random() * Math.PI * 2,
+ twinkleSpeed: 0.008 + Math.random() * 0.012,
+ }));
+ }
+
+ spawnShooter() {
+ const edge = Math.random() < 0.7 ? 'top' : 'left';
+ const angle = (Math.random() * 30 + 15) * (Math.PI / 180);
+ this.shooters.push({
+ x: edge === 'top' ? Math.random() * this.W : -10,
+ y: edge === 'top' ? -10 : Math.random() * this.H * 0.5,
+ vx: Math.cos(angle) * (6 + Math.random() * 6),
+ vy: Math.sin(angle) * (6 + Math.random() * 6),
+ len: 80 + Math.random() * 120,
+ life: 1,
+ decay: 0.012 + Math.random() * 0.018,
+ });
+ }
+
+ tick() {
+ this.raf = requestAnimationFrame(this.tick.bind(this));
+ this.frame++;
+ const ctx = this.ctx;
+ const W = this.W, H = this.H;
+
+ // Background
+ ctx.fillStyle = '#00000f';
+ ctx.fillRect(0, 0, W, H);
+
+ // Nebula clouds (subtle)
+ if (this.useNebula) {
+ this.nebulaTime += 0.003;
+ this.drawNebula(ctx, W, H);
+ }
+
+ // Stars
+ for (const s of this.stars) {
+ const twinkle = 0.55 + 0.45 * Math.sin(this.frame * s.twinkleSpeed + s.twinkleOffset);
+ const radius = s.r * (this.useDepth ? (0.3 + s.z * 0.7) : 1);
+ const alpha = (this.useDepth ? (0.25 + s.z * 0.75) : 1) * twinkle;
+
+ // Tiny drift
+ s.x += (s.z * this.speed * 0.08) * (this.useDepth ? 1 : 0);
+ s.y += (s.z * this.speed * 0.04) * (this.useDepth ? 1 : 0);
+ if (s.x > W + 2) s.x = -2;
+ if (s.y > H + 2) s.y = -2;
+
+ // Glow for bright stars
+ if (radius > 1.1 && alpha > 0.6) {
+ const grd = ctx.createRadialGradient(s.x, s.y, 0, s.x, s.y, radius * 3.5);
+        grd.addColorStop(0, hexToRgba(s.color, alpha * 0.5)); // s.color is hex ('#rrggbb'); the old .replace(')')/.replace('rgb') chain was a no-op, so the glow ignored alpha
+ grd.addColorStop(1, 'transparent');
+ ctx.beginPath();
+ ctx.arc(s.x, s.y, radius * 3.5, 0, Math.PI * 2);
+ ctx.fillStyle = grd;
+ ctx.fill();
+ }
+
+ ctx.beginPath();
+ ctx.arc(s.x, s.y, radius, 0, Math.PI * 2);
+ ctx.fillStyle = hexToRgba(s.color, alpha);
+ ctx.fill();
+ }
+
+ // Shooting stars
+ if (this.useShooting) {
+ if (this.frame % 140 === 0 && Math.random() < 0.65) this.spawnShooter();
+ for (let i = this.shooters.length - 1; i >= 0; i--) {
+ const s = this.shooters[i];
+ const tailX = s.x - s.vx * (s.len / Math.hypot(s.vx, s.vy));
+ const tailY = s.y - s.vy * (s.len / Math.hypot(s.vx, s.vy));
+
+ const grd = ctx.createLinearGradient(tailX, tailY, s.x, s.y);
+ grd.addColorStop(0, `rgba(255,255,255,0)`);
+ grd.addColorStop(0.7, `rgba(200,220,255,${s.life * 0.5})`);
+ grd.addColorStop(1, `rgba(255,255,255,${s.life})`);
+
+ ctx.beginPath();
+ ctx.moveTo(tailX, tailY);
+ ctx.lineTo(s.x, s.y);
+ ctx.strokeStyle = grd;
+ ctx.lineWidth = 1.5 * s.life;
+ ctx.lineCap = 'round';
+ ctx.stroke();
+
+ // Head dot
+ ctx.beginPath();
+ ctx.arc(s.x, s.y, 1.5 * s.life, 0, Math.PI * 2);
+ ctx.fillStyle = `rgba(255,255,255,${s.life})`;
+ ctx.fill();
+
+ s.x += s.vx;
+ s.y += s.vy;
+ s.life -= s.decay;
+
+ if (s.life <= 0 || s.x > W + 200 || s.y > H + 200) {
+ this.shooters.splice(i, 1);
+ }
+ }
+ }
+ }
+
+ drawNebula(ctx, W, H) {
+ const t = this.nebulaTime;
+ const blobs = [
+ { x: W * 0.25, y: H * 0.3, rx: W * 0.35, ry: H * 0.25, color: '40,60,180', a: 0.055 },
+ { x: W * 0.75, y: H * 0.65, rx: W * 0.30, ry: H * 0.22, color: '100,40,160', a: 0.04 },
+ { x: W * 0.5, y: H * 0.5, rx: W * 0.45, ry: H * 0.35, color: '20,50,120', a: 0.035 },
+ ];
+ ctx.save();
+ for (const b of blobs) {
+ const ox = Math.sin(t * 0.7 + b.x) * 30;
+ const oy = Math.cos(t * 0.5 + b.y) * 20;
+ const grd = ctx.createRadialGradient(b.x + ox, b.y + oy, 0, b.x + ox, b.y + oy, Math.max(b.rx, b.ry));
+ grd.addColorStop(0, `rgba(${b.color}, ${b.a})`);
+ grd.addColorStop(0.5, `rgba(${b.color}, ${b.a * 0.4})`);
+ grd.addColorStop(1, `rgba(${b.color}, 0)`);
+ ctx.save();
+ ctx.scale(b.rx / Math.max(b.rx, b.ry), b.ry / Math.max(b.rx, b.ry));
+ ctx.beginPath();
+ const scale = Math.max(b.rx, b.ry);
+ ctx.arc((b.x + ox) / (b.rx / scale), (b.y + oy) / (b.ry / scale), scale, 0, Math.PI * 2);
+ ctx.fillStyle = grd;
+ ctx.fill();
+ ctx.restore();
+ }
+ ctx.restore();
+ }
+}
+
+function hexToRgba(hex, alpha) {
+ // Handle named-ish values or full hex
+ const c = hex.startsWith('#') ? hex : '#ffffff';
+ const r = parseInt(c.slice(1,3), 16);
+ const g = parseInt(c.slice(3,5), 16);
+ const b = parseInt(c.slice(5,7), 16);
+ return `rgba(${r},${g},${b},${alpha.toFixed(3)})`;
+}
+
+customElements.define('starfield-bg', StarfieldBg);
+</script>
+
+<script>
+(function() {
+ const canvas = document.getElementById('globe');
+ const ctx = canvas.getContext('2d');
+ const W = canvas.width, H = canvas.height;
+ const R = W * 0.44;
+ const cx = W / 2, cy = H / 2;
+
+ // Simplified continent outlines as lon/lat polygon chains (degrees).
+ // Each continent is an array of [lon, lat] points.
+ const continents = [
+ // North America
+ [[-130,50],[-125,55],[-120,60],[-115,65],[-100,68],[-85,70],[-75,65],[-60,52],[-65,45],[-70,42],[-75,35],[-80,30],[-85,28],[-90,28],[-95,25],[-100,20],[-105,20],[-110,25],[-115,30],[-120,35],[-125,42],[-130,50]],
+ // South America
+ [[-80,10],[-75,5],[-70,5],[-65,0],[-60,-5],[-55,-5],[-50,-10],[-45,-15],[-40,-20],[-40,-25],[-42,-30],[-48,-32],[-52,-34],[-55,-38],[-60,-42],[-65,-50],[-68,-55],[-70,-48],[-72,-40],[-75,-30],[-78,-15],[-80,-5],[-80,5],[-80,10]],
+ // Europe
+ [[-10,36],[-5,38],[0,40],[2,43],[5,44],[8,46],[10,48],[15,50],[18,54],[20,56],[25,58],[28,60],[30,62],[35,65],[40,68],[38,60],[35,55],[30,50],[28,48],[25,45],[22,40],[20,38],[15,36],[10,36],[5,36],[0,36],[-5,36],[-10,36]],
+ // Africa
+ [[-15,14],[-17,16],[-15,22],[-12,28],[-5,32],[0,35],[5,37],[10,35],[15,32],[20,30],[25,30],[30,28],[35,25],[38,18],[40,12],[42,5],[44,0],[42,-5],[40,-12],[38,-18],[35,-25],[32,-30],[30,-34],[25,-33],[20,-30],[15,-28],[12,-20],[10,-10],[8,-5],[5,0],[2,5],[0,5],[-5,5],[-10,6],[-15,10],[-15,14]],
+ // Asia (simplified)
+ [[30,35],[35,38],[40,40],[45,42],[50,45],[55,48],[60,50],[65,55],[70,60],[75,65],[80,68],[90,70],[100,68],[110,65],[120,60],[125,55],[130,50],[135,45],[140,40],[138,35],[130,30],[120,25],[110,20],[105,15],[100,10],[95,12],[90,20],[85,22],[80,25],[75,28],[70,30],[65,35],[55,35],[45,35],[40,35],[35,35],[30,35]],
+ // Australia
+ [[115,-12],[120,-14],[125,-15],[130,-14],[135,-13],[138,-16],[140,-18],[145,-20],[148,-22],[150,-25],[152,-28],[150,-33],[148,-35],[145,-37],[140,-38],[135,-36],[130,-33],[125,-30],[120,-25],[118,-22],[116,-20],[114,-18],[115,-15],[115,-12]],
+ ];
+
+ function project(lon, lat, rotation) {
+ // Convert to radians and apply rotation
+ var lonR = (lon + rotation) * Math.PI / 180;
+ var latR = lat * Math.PI / 180;
+
+ var x3 = Math.cos(latR) * Math.sin(lonR);
+ var y3 = -Math.sin(latR);
+ var z3 = Math.cos(latR) * Math.cos(lonR);
+
+ // Only visible if facing us
+ if (z3 < 0) return null;
+
+ return { x: cx + x3 * R, y: cy + y3 * R, z: z3 };
+ }
+
+ var rotation = 0;
+
+ function draw() {
+ requestAnimationFrame(draw);
+ rotation += 0.15;
+ ctx.clearRect(0, 0, W, H);
+
+ // Atmosphere glow
+ var atm = ctx.createRadialGradient(cx, cy, R * 0.85, cx, cy, R * 1.15);
+ atm.addColorStop(0, 'rgba(60,130,255,0.12)');
+ atm.addColorStop(0.5, 'rgba(60,130,255,0.06)');
+ atm.addColorStop(1, 'rgba(60,130,255,0)');
+ ctx.beginPath();
+ ctx.arc(cx, cy, R * 1.15, 0, Math.PI * 2);
+ ctx.fillStyle = atm;
+ ctx.fill();
+
+ // Ocean sphere
+ var oceanGrad = ctx.createRadialGradient(cx - R * 0.3, cy - R * 0.3, R * 0.1, cx, cy, R);
+ oceanGrad.addColorStop(0, '#1a4a8a');
+ oceanGrad.addColorStop(0.5, '#0e2d5e');
+ oceanGrad.addColorStop(1, '#071838');
+ ctx.beginPath();
+ ctx.arc(cx, cy, R, 0, Math.PI * 2);
+ ctx.fillStyle = oceanGrad;
+ ctx.fill();
+
+ // Draw continents
+ for (var c = 0; c < continents.length; c++) {
+ var pts = continents[c];
+ var projected = [];
+ var allVisible = true;
+
+ for (var i = 0; i < pts.length; i++) {
+ var p = project(pts[i][0], pts[i][1], rotation);
+ if (!p) { allVisible = false; break; }
+ projected.push(p);
+ }
+
+ if (!allVisible || projected.length < 3) continue;
+
+ ctx.beginPath();
+ ctx.moveTo(projected[0].x, projected[0].y);
+ for (var i = 1; i < projected.length; i++) {
+ ctx.lineTo(projected[i].x, projected[i].y);
+ }
+ ctx.closePath();
+
+ // Shade based on average depth
+ var avgZ = 0;
+ for (var i = 0; i < projected.length; i++) avgZ += projected[i].z;
+ avgZ /= projected.length;
+ var brightness = 0.3 + avgZ * 0.7;
+
+ var r = Math.round(30 * brightness);
+ var g = Math.round(100 * brightness);
+ var b = Math.round(50 * brightness);
+ ctx.fillStyle = 'rgb(' + r + ',' + g + ',' + b + ')';
+ ctx.fill();
+ }
+
+ // Grid lines (longitude)
+ ctx.strokeStyle = 'rgba(100,160,255,0.08)';
+ ctx.lineWidth = 0.7;
+ for (var lon = -180; lon < 180; lon += 30) {
+ ctx.beginPath();
+ var started = false;
+ for (var lat = -90; lat <= 90; lat += 3) {
+ var p = project(lon, lat, rotation);
+ if (p) {
+ if (!started) { ctx.moveTo(p.x, p.y); started = true; }
+ else ctx.lineTo(p.x, p.y);
+ } else {
+ started = false;
+ }
+ }
+ ctx.stroke();
+ }
+
+ // Grid lines (latitude)
+ for (var lat = -60; lat <= 60; lat += 30) {
+ ctx.beginPath();
+ var started = false;
+ for (var lon = -180; lon <= 180; lon += 3) {
+ var p = project(lon, lat, rotation);
+ if (p) {
+ if (!started) { ctx.moveTo(p.x, p.y); started = true; }
+ else ctx.lineTo(p.x, p.y);
+ } else {
+ started = false;
+ }
+ }
+ ctx.stroke();
+ }
+
+ // Specular highlight
+ var spec = ctx.createRadialGradient(cx - R * 0.35, cy - R * 0.35, 0, cx - R * 0.35, cy - R * 0.35, R * 0.8);
+ spec.addColorStop(0, 'rgba(180,210,255,0.18)');
+ spec.addColorStop(0.4, 'rgba(120,160,255,0.05)');
+ spec.addColorStop(1, 'rgba(0,0,0,0)');
+ ctx.beginPath();
+ ctx.arc(cx, cy, R, 0, Math.PI * 2);
+ ctx.fillStyle = spec;
+ ctx.fill();
+
+ // Rim light
+ ctx.beginPath();
+ ctx.arc(cx, cy, R, 0, Math.PI * 2);
+ ctx.strokeStyle = 'rgba(80,140,255,0.2)';
+ ctx.lineWidth = 1.5;
+ ctx.stroke();
+ }
+
+ draw();
+})();
+</script>
+</body>
+</html>
diff --git a/src/zenserver/frontend/html/banner.js b/src/zenserver/frontend/html/banner.js
new file mode 100644
index 000000000..2e878dedf
--- /dev/null
+++ b/src/zenserver/frontend/html/banner.js
@@ -0,0 +1,338 @@
+/**
+ * zen-banner.js — Zen dashboard banner Web Component
+ *
+ * Usage:
+ * <script src="banner.js" defer></script>
+ *
+ * <zen-banner></zen-banner>
+ * <zen-banner variant="compact"></zen-banner>
+ * <zen-banner cluster-status="degraded" load="78"></zen-banner>
+ *
+ * Attributes:
+ * variant "full" (default) | "compact"
+ * cluster-status "nominal" (default) | "degraded" | "offline"
+ * load 0–100 integer, shown as a percentage (default: hidden)
+ * tagline custom tagline text (default: "Orchestrator Overview" / "Orchestrator")
+ * subtitle text after "ZEN" in the wordmark (default: "COMPUTE")
+ */
+
+class ZenBanner extends HTMLElement {
+
+ static get observedAttributes() {
+ return ['variant', 'cluster-status', 'load', 'tagline', 'subtitle', 'logo-src'];
+ }
+
+ attributeChangedCallback() {
+ if (this.shadowRoot) this._render();
+ }
+
+ connectedCallback() {
+ if (!this.shadowRoot) this.attachShadow({ mode: 'open' });
+ this._render();
+ }
+
+ // ─────────────────────────────────────────────
+ // Derived values
+ // ─────────────────────────────────────────────
+
+ get _variant() { return this.getAttribute('variant') || 'full'; }
+ get _status() { return (this.getAttribute('cluster-status') || 'nominal').toLowerCase(); }
+ get _load() { return this.getAttribute('load'); } // null → hidden
+ get _tagline() { return this.getAttribute('tagline'); } // null → default
+ get _subtitle() { return this.getAttribute('subtitle'); } // null → "COMPUTE"
+ get _logoSrc() { return this.getAttribute('logo-src'); } // null → inline SVG
+
+ get _statusColor() {
+ return { nominal: '#7ecfb8', degraded: '#d4a84b', offline: '#c0504d' }[this._status] ?? '#7ecfb8';
+ }
+
+ get _statusLabel() {
+ return { nominal: 'NOMINAL', degraded: 'DEGRADED', offline: 'OFFLINE' }[this._status] ?? 'NOMINAL';
+ }
+
+ get _loadColor() {
+ const v = parseInt(this._load, 10);
+ if (isNaN(v)) return '#7ecfb8';
+ if (v >= 85) return '#c0504d';
+ if (v >= 60) return '#d4a84b';
+ return '#7ecfb8';
+ }
+
+ // ─────────────────────────────────────────────
+ // Render
+ // ─────────────────────────────────────────────
+
+ _render() {
+ const compact = this._variant === 'compact';
+ this.shadowRoot.innerHTML = `
+ <style>${this._css(compact)}</style>
+ ${this._html(compact)}
+ `;
+ }
+
+ // ─────────────────────────────────────────────
+ // CSS
+ // ─────────────────────────────────────────────
+
+ _css(compact) {
+ const height = compact ? '60px' : '100px';
+ const padding = compact ? '0 24px' : '0 32px';
+ const gap = compact ? '16px' : '24px';
+ const markSize = compact ? '34px' : '52px';
+ const divH = compact ? '32px' : '48px';
+ const nameSize = compact ? '15px' : '22px';
+ const tagSize = compact ? '9px' : '11px';
+ const sc = this._statusColor;
+ const lc = this._loadColor;
+
+ return `
+ @import url('https://fonts.googleapis.com/css2?family=Noto+Serif+JP:wght@300;400&family=Space+Mono:wght@400;700&display=swap');
+
+ *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
+
+ :host {
+ display: block;
+ font-family: 'Space Mono', monospace;
+ }
+
+ .banner {
+ width: 100%;
+ height: ${height};
+ background: var(--theme_g3, #0b0d10);
+ border: 1px solid var(--theme_g2, #1e2330);
+ border-radius: 6px;
+ display: flex;
+ align-items: center;
+ padding: ${padding};
+ gap: ${gap};
+ position: relative;
+ overflow: hidden;
+ text-decoration: none;
+ color: inherit;
+ cursor: pointer;
+ }
+
+ /* scan-line texture */
+ .banner::before {
+ content: '';
+ position: absolute;
+ inset: 0;
+ background: repeating-linear-gradient(
+ 0deg,
+ transparent, transparent 3px,
+ rgba(255,255,255,0.012) 3px, rgba(255,255,255,0.012) 4px
+ );
+ pointer-events: none;
+ }
+
+ /* ambient glow */
+ .banner::after {
+ content: '';
+ position: absolute;
+ right: -60px;
+ top: 50%;
+ transform: translateY(-50%);
+ width: 280px;
+ height: 280px;
+ background: radial-gradient(circle, rgba(130,200,180,0.06) 0%, transparent 70%);
+ pointer-events: none;
+ }
+
+ .logo-mark {
+ flex-shrink: 0;
+ width: ${markSize};
+ height: ${markSize};
+ }
+
+ .logo-mark svg, .logo-mark img { width: 100%; height: 100%; object-fit: contain; }
+
+ .divider {
+ width: 1px;
+ height: ${divH};
+ background: linear-gradient(to bottom, transparent, var(--theme_g2, #2a3040), transparent);
+ flex-shrink: 0;
+ }
+
+ .text-block {
+ display: flex;
+ flex-direction: column;
+ gap: 4px;
+ }
+
+ .wordmark {
+ font-weight: 700;
+ font-size: ${nameSize};
+ letter-spacing: 0.12em;
+ color: var(--theme_bright, #e8e4dc);
+ text-transform: uppercase;
+ line-height: 1;
+ }
+
+ .wordmark span { color: #7ecfb8; }
+
+ .tagline {
+ font-family: 'Noto Serif JP', serif;
+ font-weight: 300;
+ font-size: ${tagSize};
+ letter-spacing: 0.3em;
+ color: var(--theme_faint, #4a5a68);
+ text-transform: uppercase;
+ }
+
+ .spacer { flex: 1; }
+
+ /* ── right-side decorative circuit ── */
+ .circuit { flex-shrink: 0; opacity: 0.22; }
+
+ /* ── status cluster ── */
+ .status-cluster {
+ display: flex;
+ flex-direction: column;
+ align-items: flex-end;
+ gap: 6px;
+ }
+
+ .status-row {
+ display: flex;
+ align-items: center;
+ gap: 8px;
+ }
+
+ .status-lbl {
+ font-size: 9px;
+ letter-spacing: 0.18em;
+ color: var(--theme_faint, #3a4555);
+ text-transform: uppercase;
+ }
+
+ .pill {
+ display: flex;
+ align-items: center;
+ gap: 5px;
+ border-radius: 20px;
+ padding: 2px 10px;
+ font-size: 10px;
+ letter-spacing: 0.1em;
+ }
+
+ .pill.cluster {
+ color: ${sc};
+ background: color-mix(in srgb, ${sc} 8%, transparent);
+ border: 1px solid color-mix(in srgb, ${sc} 28%, transparent);
+ }
+
+ .pill.load-pill {
+ color: ${lc};
+ background: color-mix(in srgb, ${lc} 8%, transparent);
+ border: 1px solid color-mix(in srgb, ${lc} 28%, transparent);
+ }
+
+ .dot {
+ width: 5px;
+ height: 5px;
+ border-radius: 50%;
+ animation: pulse 2.4s ease-in-out infinite;
+ }
+
+ .dot.cluster { background: ${sc}; }
+ .dot.load-dot { background: ${lc}; animation-delay: 0.5s; }
+
+ @keyframes pulse {
+ 0%, 100% { opacity: 1; }
+ 50% { opacity: 0.25; }
+ }
+ `;
+ }
+
+ // ─────────────────────────────────────────────
+ // HTML template
+ // ─────────────────────────────────────────────
+
+ _html(compact) {
+ const loadAttr = this._load;
+ const hasCluster = !compact && this.hasAttribute('cluster-status');
+ const hasLoad = !compact && loadAttr !== null;
+ const showRight = hasCluster || hasLoad;
+
+ const circuit = showRight ? `
+ <svg class="circuit" width="60" height="60" viewBox="0 0 60 60" fill="none">
+ <path d="M5 30 H22 L28 18 H60" stroke="#7ecfb8" stroke-width="0.8"/>
+ <path d="M5 38 H18 L24 46 H60" stroke="#7ecfb8" stroke-width="0.8"/>
+ <circle cx="22" cy="30" r="2" fill="none" stroke="#7ecfb8" stroke-width="0.8"/>
+ <circle cx="18" cy="38" r="2" fill="none" stroke="#7ecfb8" stroke-width="0.8"/>
+ <circle cx="10" cy="30" r="1.2" fill="#7ecfb8"/>
+ <circle cx="10" cy="38" r="1.2" fill="#7ecfb8"/>
+ </svg>` : '';
+
+ const clusterRow = hasCluster ? `
+ <div class="status-row">
+ <span class="status-lbl">Cluster</span>
+ <div class="pill cluster">
+ <div class="dot cluster"></div>
+ ${this._statusLabel}
+ </div>
+ </div>` : '';
+
+ const loadRow = hasLoad ? `
+ <div class="status-row">
+ <span class="status-lbl">Load</span>
+ <div class="pill load-pill">
+ <div class="dot load-dot"></div>
+ ${parseInt(loadAttr, 10)} %
+ </div>
+ </div>` : '';
+
+ const rightSide = showRight ? `
+ ${circuit}
+ <div class="status-cluster">
+ ${clusterRow}
+ ${loadRow}
+ </div>
+ ` : '';
+
+ return `
+ <a class="banner" href="/dashboard/">
+ <div class="logo-mark">${this._logoMark()}</div>
+ <div class="divider"></div>
+ <div class="text-block">
+ <div class="wordmark">ZEN<span> ${this._subtitle ?? 'COMPUTE'}</span></div>
+ <div class="tagline">${this._tagline ?? (compact ? 'Orchestrator' : 'Orchestrator Overview')}</div>
+ </div>
+ <div class="spacer"></div>
+ ${rightSide}
+ </a>
+ `;
+ }
+
+ // ─────────────────────────────────────────────
+ // SVG logo mark
+ // ─────────────────────────────────────────────
+
+ _logoMark() {
+ const src = this._logoSrc;
+ if (src) {
+ return `<img src="${src}" alt="zen">`;
+ }
+ return `
+ <svg viewBox="0 0 52 52" fill="none" xmlns="http://www.w3.org/2000/svg">
+ <circle cx="26" cy="26" r="22" stroke="#2a3a48" stroke-width="1.5"/>
+ <path d="M26 4 A22 22 0 1 1 12 43.1" stroke="#7ecfb8" stroke-width="2" stroke-linecap="round" fill="none"/>
+ <circle cx="17" cy="17" r="1.6" fill="#7ecfb8" />
+ <circle cx="26" cy="17" r="1.6" fill="#7ecfb8" />
+ <circle cx="35" cy="17" r="1.6" fill="#7ecfb8" />
+ <circle cx="17" cy="26" r="1.6" fill="#7ecfb8" opacity="0.6"/>
+ <circle cx="26" cy="26" r="2.2" fill="#7ecfb8"/>
+ <circle cx="35" cy="26" r="1.6" fill="#7ecfb8" opacity="0.6"/>
+ <circle cx="17" cy="35" r="1.6" fill="#7ecfb8"/>
+ <circle cx="26" cy="35" r="1.6" fill="#7ecfb8"/>
+ <circle cx="35" cy="35" r="1.6" fill="#7ecfb8"/>
+ <line x1="17" y1="17" x2="35" y2="17" stroke="#7ecfb8" stroke-width="0.7" stroke-opacity="0.25"/>
+ <line x1="35" y1="17" x2="17" y2="35" stroke="#7ecfb8" stroke-width="0.7" stroke-opacity="0.25"/>
+ <line x1="17" y1="35" x2="35" y2="35" stroke="#7ecfb8" stroke-width="0.7" stroke-opacity="0.2"/>
+ <line x1="26" y1="17" x2="26" y2="35" stroke="#7ecfb8" stroke-width="0.7" stroke-opacity="0.2"/>
+ </svg>
+ `;
+ }
+}
+
+customElements.define('zen-banner', ZenBanner);
diff --git a/src/zenserver/frontend/html/compute/compute.html b/src/zenserver/frontend/html/compute/compute.html
new file mode 100644
index 000000000..66c20175f
--- /dev/null
+++ b/src/zenserver/frontend/html/compute/compute.html
@@ -0,0 +1,929 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <title>Zen Compute Dashboard</title>
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/chart.umd.min.js"></script>
+ <link rel="stylesheet" type="text/css" href="../zen.css" />
+ <script src="../theme.js"></script>
+ <script src="../banner.js" defer></script>
+ <script src="../nav.js" defer></script>
+ <style>
+ .grid {
+ grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
+ }
+
+ .chart-container {
+ position: relative;
+ height: 300px;
+ margin-top: 20px;
+ }
+
+ .stats-row {
+ display: flex;
+ justify-content: space-between;
+ margin-bottom: 12px;
+ padding: 8px 0;
+ border-bottom: 1px solid var(--theme_border_subtle);
+ }
+
+ .stats-row:last-child {
+ border-bottom: none;
+ margin-bottom: 0;
+ }
+
+ .stats-label {
+ color: var(--theme_g1);
+ font-size: 13px;
+ }
+
+ .stats-value {
+ color: var(--theme_bright);
+ font-weight: 600;
+ font-size: 13px;
+ }
+
+ .rate-stats {
+ display: grid;
+ grid-template-columns: repeat(3, 1fr);
+ gap: 16px;
+ margin-top: 16px;
+ }
+
+ .rate-item {
+ text-align: center;
+ }
+
+ .rate-value {
+ font-size: 20px;
+ font-weight: 600;
+ color: var(--theme_p0);
+ }
+
+ .rate-label {
+ font-size: 11px;
+ color: var(--theme_g1);
+ margin-top: 4px;
+ text-transform: uppercase;
+ }
+
+ .worker-row {
+ cursor: pointer;
+ transition: background 0.15s;
+ }
+
+ .worker-row:hover {
+ background: var(--theme_p4);
+ }
+
+ .worker-row.selected {
+ background: var(--theme_p3);
+ }
+
+ .worker-detail {
+ margin-top: 20px;
+ border-top: 1px solid var(--theme_g2);
+ padding-top: 16px;
+ }
+
+ .worker-detail-title {
+ font-size: 15px;
+ font-weight: 600;
+ color: var(--theme_bright);
+ margin-bottom: 12px;
+ }
+
+ .detail-section {
+ margin-bottom: 16px;
+ }
+
+ .detail-section-label {
+ font-size: 11px;
+ font-weight: 600;
+ color: var(--theme_g1);
+ text-transform: uppercase;
+ letter-spacing: 0.5px;
+ margin-bottom: 6px;
+ }
+
+ .detail-table {
+ width: 100%;
+ border-collapse: collapse;
+ font-size: 12px;
+ }
+
+ .detail-table td {
+ padding: 4px 8px;
+ color: var(--theme_g0);
+ border-bottom: 1px solid var(--theme_border_subtle);
+ vertical-align: top;
+ }
+
+ .detail-table td:first-child {
+ color: var(--theme_g1);
+ width: 40%;
+ font-family: monospace;
+ }
+
+ .detail-table tr:last-child td {
+ border-bottom: none;
+ }
+
+ .detail-mono {
+ font-family: monospace;
+ font-size: 11px;
+ color: var(--theme_g1);
+ }
+
+ .detail-tag {
+ display: inline-block;
+ padding: 2px 8px;
+ border-radius: 4px;
+ background: var(--theme_border_subtle);
+ color: var(--theme_g0);
+ font-size: 11px;
+ margin: 2px 4px 2px 0;
+ }
+ </style>
+</head>
+<body>
+ <div class="container" style="max-width: 1400px; margin: 0 auto;">
+ <zen-banner cluster-status="nominal" load="0" tagline="Node Overview" logo-src="../favicon.ico"></zen-banner>
+ <zen-nav>
+ <a href="/dashboard/">Home</a>
+ <a href="compute.html">Node</a>
+ <a href="orchestrator.html">Orchestrator</a>
+ </zen-nav>
+ <div class="timestamp">Last updated: <span id="last-update">Never</span></div>
+
+ <div id="error-container"></div>
+
+ <!-- Action Queue Stats -->
+ <div class="section-title">Action Queue</div>
+ <div class="grid">
+ <div class="card">
+ <div class="card-title">Pending Actions</div>
+ <div class="metric-value" id="actions-pending">-</div>
+ <div class="metric-label">Waiting to be scheduled</div>
+ </div>
+ <div class="card">
+ <div class="card-title">Running Actions</div>
+ <div class="metric-value" id="actions-running">-</div>
+ <div class="metric-label">Currently executing</div>
+ </div>
+ <div class="card">
+ <div class="card-title">Completed Actions</div>
+ <div class="metric-value" id="actions-complete">-</div>
+ <div class="metric-label">Results available</div>
+ </div>
+ </div>
+
+ <!-- Action Queue Chart -->
+ <div class="card" style="margin-bottom: 30px;">
+ <div class="card-title">Action Queue History</div>
+ <div class="chart-container">
+ <canvas id="queue-chart"></canvas>
+ </div>
+ </div>
+
+ <!-- Performance Metrics -->
+ <div class="section-title">Performance Metrics</div>
+ <div class="card" style="margin-bottom: 30px;">
+ <div class="card-title">Completion Rate</div>
+ <div class="rate-stats">
+ <div class="rate-item">
+ <div class="rate-value" id="rate-1">-</div>
+ <div class="rate-label">1 min rate</div>
+ </div>
+ <div class="rate-item">
+ <div class="rate-value" id="rate-5">-</div>
+ <div class="rate-label">5 min rate</div>
+ </div>
+ <div class="rate-item">
+ <div class="rate-value" id="rate-15">-</div>
+ <div class="rate-label">15 min rate</div>
+ </div>
+ </div>
+ <div style="margin-top: 20px;">
+ <div class="stats-row">
+ <span class="stats-label">Total Retired</span>
+ <span class="stats-value" id="retired-count">-</span>
+ </div>
+ <div class="stats-row">
+ <span class="stats-label">Mean Rate</span>
+ <span class="stats-value" id="rate-mean">-</span>
+ </div>
+ </div>
+ </div>
+
+ <!-- Workers -->
+ <div class="section-title">Workers</div>
+ <div class="card" style="margin-bottom: 30px;">
+ <div class="card-title">Worker Status</div>
+ <div class="stats-row">
+ <span class="stats-label">Registered Workers</span>
+ <span class="stats-value" id="worker-count">-</span>
+ </div>
+ <div id="worker-table-container" style="margin-top: 16px; display: none;">
+ <table id="worker-table">
+ <thead>
+ <tr>
+ <th>Name</th>
+ <th>Platform</th>
+ <th style="text-align: right;">Cores</th>
+ <th style="text-align: right;">Timeout</th>
+ <th style="text-align: right;">Functions</th>
+ <th>Worker ID</th>
+ </tr>
+ </thead>
+ <tbody id="worker-table-body"></tbody>
+ </table>
+ <div id="worker-detail" class="worker-detail" style="display: none;"></div>
+ </div>
+ </div>
+
+ <!-- Queues -->
+ <div class="section-title">Queues</div>
+ <div class="card" style="margin-bottom: 30px;">
+ <div class="card-title">Queue Status</div>
+ <div id="queue-list-empty" class="empty-state" style="text-align: left;">No queues.</div>
+ <div id="queue-list-container" style="display: none;">
+ <table id="queue-list-table">
+ <thead>
+ <tr>
+ <th style="text-align: right; width: 60px;">ID</th>
+ <th style="text-align: center; width: 80px;">Status</th>
+ <th style="text-align: right;">Active</th>
+ <th style="text-align: right;">Completed</th>
+ <th style="text-align: right;">Failed</th>
+ <th style="text-align: right;">Abandoned</th>
+ <th style="text-align: right;">Cancelled</th>
+ <th>Token</th>
+ </tr>
+ </thead>
+ <tbody id="queue-list-body"></tbody>
+ </table>
+ </div>
+ </div>
+
+ <!-- Action History -->
+ <div class="section-title">Recent Actions</div>
+ <div class="card" style="margin-bottom: 30px;">
+ <div class="card-title">Action History</div>
+ <div id="action-history-empty" class="empty-state" style="text-align: left;">No actions recorded yet.</div>
+ <div id="action-history-container" style="display: none;">
+ <table id="action-history-table">
+ <thead>
+ <tr>
+ <th style="text-align: right; width: 60px;">LSN</th>
+ <th style="text-align: right; width: 60px;">Queue</th>
+ <th style="text-align: center; width: 70px;">Status</th>
+ <th>Function</th>
+ <th style="text-align: right; width: 80px;">Started</th>
+ <th style="text-align: right; width: 80px;">Finished</th>
+ <th style="text-align: right; width: 80px;">Duration</th>
+ <th>Worker ID</th>
+ <th>Action ID</th>
+ </tr>
+ </thead>
+ <tbody id="action-history-body"></tbody>
+ </table>
+ </div>
+ </div>
+
+ <!-- System Resources -->
+ <div class="section-title">System Resources</div>
+ <div class="grid">
+ <div class="card">
+ <div class="card-title">CPU Usage</div>
+ <div class="metric-value" id="cpu-usage">-</div>
+ <div class="metric-label">Percent</div>
+ <div class="progress-bar">
+ <div class="progress-fill" id="cpu-progress" style="width: 0%"></div>
+ </div>
+ <div style="position: relative; height: 60px; margin-top: 12px;">
+ <canvas id="cpu-chart"></canvas>
+ </div>
+ <div style="margin-top: 12px;">
+ <div class="stats-row">
+ <span class="stats-label">Packages</span>
+ <span class="stats-value" id="cpu-packages">-</span>
+ </div>
+ <div class="stats-row">
+ <span class="stats-label">Physical Cores</span>
+ <span class="stats-value" id="cpu-cores">-</span>
+ </div>
+ <div class="stats-row">
+ <span class="stats-label">Logical Processors</span>
+ <span class="stats-value" id="cpu-lp">-</span>
+ </div>
+ </div>
+ </div>
+ <div class="card">
+ <div class="card-title">Memory</div>
+ <div class="stats-row">
+ <span class="stats-label">Used</span>
+ <span class="stats-value" id="memory-used">-</span>
+ </div>
+ <div class="stats-row">
+ <span class="stats-label">Total</span>
+ <span class="stats-value" id="memory-total">-</span>
+ </div>
+ <div class="progress-bar">
+ <div class="progress-fill" id="memory-progress" style="width: 0%"></div>
+ </div>
+ </div>
+ <div class="card">
+ <div class="card-title">Disk</div>
+ <div class="stats-row">
+ <span class="stats-label">Used</span>
+ <span class="stats-value" id="disk-used">-</span>
+ </div>
+ <div class="stats-row">
+ <span class="stats-label">Total</span>
+ <span class="stats-value" id="disk-total">-</span>
+ </div>
+ <div class="progress-bar">
+ <div class="progress-fill" id="disk-progress" style="width: 0%"></div>
+ </div>
+ </div>
+ </div>
+ </div>
+
+ <script>
+ // Configuration
+ const BASE_URL = window.location.origin;
+ const REFRESH_INTERVAL = 2000; // 2 seconds
+ const MAX_HISTORY_POINTS = 60; // Show last 2 minutes
+
+ // Data storage
+ // Rolling series feeding the two charts below; every fetch appends one
+ // point and the arrays are trimmed to MAX_HISTORY_POINTS entries.
+ // NOTE(review): 'history' shadows window.history inside this script —
+ // intentional here, but rename if browser history is ever needed.
+ const history = {
+ timestamps: [],
+ pending: [],
+ running: [],
+ completed: [],
+ cpu: []
+ };
+
+ // CPU sparkline chart
+ // Minimal line chart: no axes, legend, tooltip or animation, with a
+ // fixed 0-100 y-range so the sparkline scale never jumps around.
+ const cpuCtx = document.getElementById('cpu-chart').getContext('2d');
+ const cpuChart = new Chart(cpuCtx, {
+ type: 'line',
+ data: {
+ labels: [],
+ datasets: [{
+ data: [],
+ borderColor: '#58a6ff',
+ backgroundColor: 'rgba(88, 166, 255, 0.15)',
+ borderWidth: 1.5,
+ tension: 0.4,
+ fill: true,
+ pointRadius: 0
+ }]
+ },
+ options: {
+ responsive: true,
+ maintainAspectRatio: false,
+ animation: false,
+ plugins: { legend: { display: false }, tooltip: { enabled: false } },
+ scales: {
+ x: { display: false },
+ y: { display: false, min: 0, max: 100 }
+ }
+ }
+ });
+
+ // Queue chart setup
+ // Three filled series (Pending/Running/Completed) over a hidden x axis;
+ // labels come from history.timestamps.
+ const ctx = document.getElementById('queue-chart').getContext('2d');
+ const chart = new Chart(ctx, {
+ type: 'line',
+ data: {
+ labels: [],
+ datasets: [
+ {
+ label: 'Pending',
+ data: [],
+ borderColor: '#f0883e',
+ backgroundColor: 'rgba(240, 136, 62, 0.1)',
+ tension: 0.4,
+ fill: true
+ },
+ {
+ label: 'Running',
+ data: [],
+ borderColor: '#58a6ff',
+ backgroundColor: 'rgba(88, 166, 255, 0.1)',
+ tension: 0.4,
+ fill: true
+ },
+ {
+ label: 'Completed',
+ data: [],
+ borderColor: '#3fb950',
+ backgroundColor: 'rgba(63, 185, 80, 0.1)',
+ tension: 0.4,
+ fill: true
+ }
+ ]
+ },
+ options: {
+ responsive: true,
+ maintainAspectRatio: false,
+ plugins: {
+ legend: {
+ display: true,
+ labels: {
+ color: '#8b949e'
+ }
+ }
+ },
+ scales: {
+ x: {
+ display: false
+ },
+ y: {
+ beginAtZero: true,
+ ticks: {
+ color: '#8b949e'
+ },
+ grid: {
+ color: '#21262d'
+ }
+ }
+ }
+ }
+ });
+
+ // Helper functions
+ function escapeHtml(text) {
+ var div = document.createElement('div');
+ div.textContent = text;
+ return div.innerHTML;
+ }
+
+ function formatBytes(bytes) {
+ if (bytes === 0) return '0 B';
+ const k = 1024;
+ const sizes = ['B', 'KB', 'MB', 'GB', 'TB'];
+ const i = Math.floor(Math.log(bytes) / Math.log(k));
+ return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
+ }
+
+ function formatRate(rate) {
+ return rate.toFixed(2) + '/s';
+ }
+
+ // Show an error banner; the message is HTML-escaped before insertion.
+ function showError(message) {
+ const container = document.getElementById('error-container');
+ container.innerHTML = `<div class="error">Error: ${escapeHtml(message)}</div>`;
+ }
+
+ // Remove any previously shown error banner.
+ function clearError() {
+ document.getElementById('error-container').innerHTML = '';
+ }
+
+ // Stamp the "Last updated" label with the current local time.
+ function updateTimestamp() {
+ const now = new Date();
+ document.getElementById('last-update').textContent = now.toLocaleTimeString();
+ }
+
+ // Fetch functions
+ async function fetchJSON(endpoint) {
+ const response = await fetch(`${BASE_URL}${endpoint}`, {
+ headers: {
+ 'Accept': 'application/json'
+ }
+ });
+ if (!response.ok) {
+ throw new Error(`HTTP ${response.status}: ${response.statusText}`);
+ }
+ return await response.json();
+ }
+
+ async function fetchHealth() {
+ try {
+ const response = await fetch(`${BASE_URL}/compute/ready`);
+ const isHealthy = response.status === 200;
+
+ const banner = document.querySelector('zen-banner');
+
+ if (isHealthy) {
+ banner.setAttribute('cluster-status', 'nominal');
+ banner.setAttribute('load', '0');
+ } else {
+ banner.setAttribute('cluster-status', 'degraded');
+ banner.setAttribute('load', '0');
+ }
+
+ return isHealthy;
+ } catch (error) {
+ const banner = document.querySelector('zen-banner');
+ banner.setAttribute('cluster-status', 'degraded');
+ banner.setAttribute('load', '0');
+ throw error;
+ }
+ }
+
+ // Pull aggregate action counts from /stats/compute, update the metric
+ // cards and completion-rate rows, then append one point per series to
+ // the queue chart (trimmed to MAX_HISTORY_POINTS).
+ async function fetchStats() {
+ const data = await fetchJSON('/stats/compute');
+
+ // Update action counts
+ document.getElementById('actions-pending').textContent = data.actions_pending || 0;
+ document.getElementById('actions-running').textContent = data.actions_submitted || 0;
+ document.getElementById('actions-complete').textContent = data.actions_complete || 0;
+
+ // Update completion rates
+ if (data.actions_retired) {
+ document.getElementById('rate-1').textContent = formatRate(data.actions_retired.rate_1 || 0);
+ document.getElementById('rate-5').textContent = formatRate(data.actions_retired.rate_5 || 0);
+ document.getElementById('rate-15').textContent = formatRate(data.actions_retired.rate_15 || 0);
+ document.getElementById('retired-count').textContent = data.actions_retired.count || 0;
+ document.getElementById('rate-mean').textContent = formatRate(data.actions_retired.rate_mean || 0);
+ }
+
+ // Update chart
+ const now = new Date().toLocaleTimeString();
+ history.timestamps.push(now);
+ history.pending.push(data.actions_pending || 0);
+ history.running.push(data.actions_submitted || 0);
+ history.completed.push(data.actions_complete || 0);
+
+ // Keep only last N points
+ if (history.timestamps.length > MAX_HISTORY_POINTS) {
+ history.timestamps.shift();
+ history.pending.shift();
+ history.running.shift();
+ history.completed.shift();
+ }
+
+ chart.data.labels = history.timestamps;
+ chart.data.datasets[0].data = history.pending;
+ chart.data.datasets[1].data = history.running;
+ chart.data.datasets[2].data = history.completed;
+ chart.update('none'); // 'none' mode skips the update animation
+ }
+
+ // Pull host metrics from /compute/sysinfo and update the CPU gauge,
+ // CPU sparkline, banner load attribute, and memory/disk usage bars.
+ async function fetchSysInfo() {
+ const data = await fetchJSON('/compute/sysinfo');
+
+ // Update CPU
+ const cpuUsage = data.cpu_usage || 0;
+ document.getElementById('cpu-usage').textContent = cpuUsage.toFixed(1) + '%';
+ document.getElementById('cpu-progress').style.width = cpuUsage + '%';
+
+ const banner = document.querySelector('zen-banner');
+ banner.setAttribute('load', cpuUsage.toFixed(1));
+
+ history.cpu.push(cpuUsage);
+ if (history.cpu.length > MAX_HISTORY_POINTS) history.cpu.shift();
+ cpuChart.data.labels = history.cpu.map(() => ''); // sparkline: no x labels
+ cpuChart.data.datasets[0].data = history.cpu;
+ cpuChart.update('none');
+
+ document.getElementById('cpu-packages').textContent = data.cpu_count ?? '-';
+ document.getElementById('cpu-cores').textContent = data.core_count ?? '-';
+ document.getElementById('cpu-lp').textContent = data.lp_count ?? '-';
+
+ // Update Memory
+ const memUsed = data.memory_used || 0;
+ const memTotal = data.memory_total || 1; // 1 avoids divide-by-zero when absent
+ const memPercent = (memUsed / memTotal) * 100;
+ document.getElementById('memory-used').textContent = formatBytes(memUsed);
+ document.getElementById('memory-total').textContent = formatBytes(memTotal);
+ document.getElementById('memory-progress').style.width = memPercent + '%';
+
+ // Update Disk
+ const diskUsed = data.disk_used || 0;
+ const diskTotal = data.disk_total || 1; // 1 avoids divide-by-zero when absent
+ const diskPercent = (diskUsed / diskTotal) * 100;
+ document.getElementById('disk-used').textContent = formatBytes(diskUsed);
+ document.getElementById('disk-total').textContent = formatBytes(diskTotal);
+ document.getElementById('disk-progress').style.width = diskPercent + '%';
+ }
+
+ // Persists the selected worker ID across refreshes
+ let selectedWorkerId = null;
+
+ // Render the detail panel for worker `id` from its descriptor `desc`:
+ // identity fields plus Functions / Executables / Files / Directories /
+ // Environment sections. A null/undefined desc hides the panel.
+ // NOTE(review): field()/monoField() interpolate values without escaping.
+ // String callers escape via escapeHtml(), but cores/timeout/
+ // buildsystem_version go in raw — confirm these are server-controlled.
+ function renderWorkerDetail(id, desc) {
+ const panel = document.getElementById('worker-detail');
+
+ if (!desc) {
+ panel.style.display = 'none';
+ return;
+ }
+
+ // Plain label/value row; missing values render as '-'.
+ function field(label, value) {
+ return `<tr><td>${label}</td><td>${value ?? '-'}</td></tr>`;
+ }
+
+ // Same as field(), but the value cell uses the monospace style.
+ function monoField(label, value) {
+ return `<tr><td>${label}</td><td class="detail-mono">${value ?? '-'}</td></tr>`;
+ }
+
+ // Functions
+ const functions = desc.functions || [];
+ const functionsHtml = functions.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
+ `<table class="detail-table">${functions.map(f =>
+ `<tr><td>${escapeHtml(f.name || '-')}</td><td class="detail-mono">${escapeHtml(f.version || '-')}</td></tr>`
+ ).join('')}</table>`;
+
+ // Executables
+ const executables = desc.executables || [];
+ const totalExecSize = executables.reduce((sum, e) => sum + (e.size || 0), 0);
+ const execHtml = executables.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
+ `<table class="detail-table">
+ <tr style="font-size:11px;">
+ <td style="color:var(--theme_faint);padding-bottom:4px;">Path</td>
+ <td style="color:var(--theme_faint);padding-bottom:4px;">Hash</td>
+ <td style="color:var(--theme_faint);padding-bottom:4px;text-align:right;">Size</td>
+ </tr>
+ ${executables.map(e =>
+ `<tr>
+ <td>${escapeHtml(e.name || '-')}</td>
+ <td class="detail-mono">${escapeHtml(e.hash || '-')}</td>
+ <td style="text-align:right;white-space:nowrap;">${e.size != null ? formatBytes(e.size) : '-'}</td>
+ </tr>`
+ ).join('')}
+ <tr style="border-top:1px solid var(--theme_g2);">
+ <td style="color:var(--theme_g1);padding-top:6px;">Total</td>
+ <td></td>
+ <td style="text-align:right;white-space:nowrap;padding-top:6px;color:var(--theme_bright);font-weight:600;">${formatBytes(totalExecSize)}</td>
+ </tr>
+ </table>`;
+
+ // Files
+ // Entries may be objects ({name, hash}) or bare strings (f.name || f).
+ const files = desc.files || [];
+ const filesHtml = files.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
+ `<table class="detail-table">${files.map(f =>
+ `<tr><td>${escapeHtml(f.name || f)}</td><td class="detail-mono">${escapeHtml(f.hash || '')}</td></tr>`
+ ).join('')}</table>`;
+
+ // Dirs
+ const dirs = desc.dirs || [];
+ const dirsHtml = dirs.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
+ dirs.map(d => `<span class="detail-tag">${escapeHtml(d)}</span>`).join('');
+
+ // Environment
+ const env = desc.environment || [];
+ const envHtml = env.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
+ env.map(e => `<span class="detail-tag">${escapeHtml(e)}</span>`).join('');
+
+ panel.innerHTML = `
+ <div class="worker-detail-title">${escapeHtml(desc.name || id)}</div>
+ <div class="detail-section">
+ <table class="detail-table">
+ ${field('Worker ID', `<span class="detail-mono">${escapeHtml(id)}</span>`)}
+ ${field('Path', escapeHtml(desc.path || '-'))}
+ ${field('Platform', escapeHtml(desc.host || '-'))}
+ ${monoField('Build System', desc.buildsystem_version)}
+ ${field('Cores', desc.cores)}
+ ${field('Timeout', desc.timeout != null ? desc.timeout + 's' : null)}
+ </table>
+ </div>
+ <div class="detail-section">
+ <div class="detail-section-label">Functions</div>
+ ${functionsHtml}
+ </div>
+ <div class="detail-section">
+ <div class="detail-section-label">Executables</div>
+ ${execHtml}
+ </div>
+ <div class="detail-section">
+ <div class="detail-section-label">Files</div>
+ ${filesHtml}
+ </div>
+ <div class="detail-section">
+ <div class="detail-section-label">Directories</div>
+ ${dirsHtml}
+ </div>
+ <div class="detail-section">
+ <div class="detail-section-label">Environment</div>
+ ${envHtml}
+ </div>
+ `;
+ panel.style.display = 'block';
+ }
+
+ async function fetchWorkers() {
+ const data = await fetchJSON('/compute/workers');
+ const workerIds = data.workers || [];
+
+ document.getElementById('worker-count').textContent = workerIds.length;
+
+ const container = document.getElementById('worker-table-container');
+ const tbody = document.getElementById('worker-table-body');
+
+ if (workerIds.length === 0) {
+ container.style.display = 'none';
+ selectedWorkerId = null;
+ return;
+ }
+
+ const descriptors = await Promise.all(
+ workerIds.map(id => fetchJSON(`/compute/workers/${id}`).catch(() => null))
+ );
+
+ // Build a map for quick lookup by ID
+ const descriptorMap = {};
+ workerIds.forEach((id, i) => { descriptorMap[id] = descriptors[i]; });
+
+ tbody.innerHTML = '';
+ descriptors.forEach((desc, i) => {
+ const id = workerIds[i];
+ const name = desc ? (desc.name || '-') : '-';
+ const host = desc ? (desc.host || '-') : '-';
+ const cores = desc ? (desc.cores != null ? desc.cores : '-') : '-';
+ const timeout = desc ? (desc.timeout != null ? desc.timeout + 's' : '-') : '-';
+ const functions = desc ? (desc.functions ? desc.functions.length : 0) : '-';
+
+ const tr = document.createElement('tr');
+ tr.className = 'worker-row' + (id === selectedWorkerId ? ' selected' : '');
+ tr.dataset.workerId = id;
+ tr.innerHTML = `
+ <td style="color: var(--theme_bright);">${escapeHtml(name)}</td>
+ <td>${escapeHtml(host)}</td>
+ <td style="text-align: right;">${escapeHtml(String(cores))}</td>
+ <td style="text-align: right;">${escapeHtml(String(timeout))}</td>
+ <td style="text-align: right;">${escapeHtml(String(functions))}</td>
+ <td style="color: var(--theme_g1); font-family: monospace; font-size: 11px;">${escapeHtml(id)}</td>
+ `;
+ tr.addEventListener('click', () => {
+ document.querySelectorAll('.worker-row').forEach(r => r.classList.remove('selected'));
+ if (selectedWorkerId === id) {
+ // Toggle off
+ selectedWorkerId = null;
+ document.getElementById('worker-detail').style.display = 'none';
+ } else {
+ selectedWorkerId = id;
+ tr.classList.add('selected');
+ renderWorkerDetail(id, descriptorMap[id]);
+ }
+ });
+ tbody.appendChild(tr);
+ });
+
+ // Re-render detail if selected worker is still present
+ if (selectedWorkerId && descriptorMap[selectedWorkerId]) {
+ renderWorkerDetail(selectedWorkerId, descriptorMap[selectedWorkerId]);
+ } else if (selectedWorkerId && !descriptorMap[selectedWorkerId]) {
+ selectedWorkerId = null;
+ document.getElementById('worker-detail').style.display = 'none';
+ }
+
+ container.style.display = 'block';
+ }
+
+ // Windows FILETIME: 100ns ticks since 1601-01-01. Convert to JS Date.
+ const FILETIME_EPOCH_OFFSET_MS = 11644473600000n;
+ function filetimeToDate(ticks) {
+ if (!ticks) return null;
+ const ms = BigInt(ticks) / 10000n - FILETIME_EPOCH_OFFSET_MS;
+ return new Date(Number(ms));
+ }
+
+ function formatTime(date) {
+ if (!date) return '-';
+ return date.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' });
+ }
+
+ function formatDuration(startDate, endDate) {
+ if (!startDate || !endDate) return '-';
+ const ms = endDate - startDate;
+ if (ms < 0) return '-';
+ if (ms < 1000) return ms + ' ms';
+ if (ms < 60000) return (ms / 1000).toFixed(2) + ' s';
+ const m = Math.floor(ms / 60000);
+ const s = ((ms % 60000) / 1000).toFixed(0).padStart(2, '0');
+ return `${m}m ${s}s`;
+ }
+
+ // Render the queue table from /compute/queues: one row per queue with
+ // a state badge (cancelled/draining/complete/active) and per-state
+ // action counters. Shows the empty state when no queues exist.
+ async function fetchQueues() {
+ const data = await fetchJSON('/compute/queues');
+ const queues = data.queues || [];
+
+ const empty = document.getElementById('queue-list-empty');
+ const container = document.getElementById('queue-list-container');
+ const tbody = document.getElementById('queue-list-body');
+
+ if (queues.length === 0) {
+ empty.style.display = '';
+ container.style.display = 'none';
+ return;
+ }
+
+ empty.style.display = 'none';
+ tbody.innerHTML = '';
+
+ for (const q of queues) {
+ const id = q.queue_id ?? '-';
+ // Badge precedence: cancelled > draining > complete > active.
+ const badge = q.state === 'cancelled'
+ ? '<span class="status-badge failure">cancelled</span>'
+ : q.state === 'draining'
+ ? '<span class="status-badge" style="background:color-mix(in srgb, var(--theme_warn) 15%, transparent);color:var(--theme_warn);">draining</span>'
+ : q.is_complete
+ ? '<span class="status-badge success">complete</span>'
+ : '<span class="status-badge" style="background:color-mix(in srgb, var(--theme_p0) 15%, transparent);color:var(--theme_p0);">active</span>';
+ const token = q.queue_token
+ ? `<span class="detail-mono">${escapeHtml(q.queue_token)}</span>`
+ : '<span style="color:var(--theme_faint);">-</span>';
+
+ const tr = document.createElement('tr');
+ tr.innerHTML = `
+ <td style="text-align: right; font-family: monospace; color: var(--theme_bright);">${escapeHtml(String(id))}</td>
+ <td style="text-align: center;">${badge}</td>
+ <td style="text-align: right;">${q.active_count ?? 0}</td>
+ <td style="text-align: right; color: var(--theme_ok);">${q.completed_count ?? 0}</td>
+ <td style="text-align: right; color: var(--theme_fail);">${q.failed_count ?? 0}</td>
+ <td style="text-align: right; color: var(--theme_warn);">${q.abandoned_count ?? 0}</td>
+ <td style="text-align: right; color: var(--theme_warn);">${q.cancelled_count ?? 0}</td>
+ <td>${token}</td>
+ `;
+ tbody.appendChild(tr);
+ }
+
+ container.style.display = 'block';
+ }
+
+ // Render the most recent 50 retired actions from /compute/jobs/history.
+ // Timestamps arrive as Windows FILETIME ticks (see filetimeToDate).
+ async function fetchActionHistory() {
+ const data = await fetchJSON('/compute/jobs/history?limit=50');
+ const entries = data.history || [];
+
+ const empty = document.getElementById('action-history-empty');
+ const container = document.getElementById('action-history-container');
+ const tbody = document.getElementById('action-history-body');
+
+ if (entries.length === 0) {
+ empty.style.display = '';
+ container.style.display = 'none';
+ return;
+ }
+
+ empty.style.display = 'none';
+ tbody.innerHTML = '';
+
+ // Entries arrive oldest-first; reverse to show newest at top
+ for (const entry of [...entries].reverse()) {
+ const lsn = entry.lsn ?? '-';
+ const succeeded = entry.succeeded;
+ // Tri-state: null/undefined -> unknown, true -> ok, false -> failed.
+ const badge = succeeded == null
+ ? '<span class="status-badge" style="background:var(--theme_border_subtle);color:var(--theme_g1);">unknown</span>'
+ : succeeded
+ ? '<span class="status-badge success">ok</span>'
+ : '<span class="status-badge failure">failed</span>';
+ const desc = entry.actionDescriptor || {};
+ const fn = desc.Function || '-';
+ const workerId = entry.workerId || '-';
+ const actionId = entry.actionId || '-';
+
+ const startDate = filetimeToDate(entry.time_Running);
+ const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed);
+
+ const queueId = entry.queueId || 0;
+ const queueCell = queueId
+ ? `<a href="/compute/queues/${queueId}" style="color: var(--theme_ln); text-decoration: none; font-family: monospace;">${escapeHtml(String(queueId))}</a>`
+ : '<span style="color: var(--theme_faint);">-</span>';
+
+ const tr = document.createElement('tr');
+ tr.innerHTML = `
+ <td style="text-align: right; font-family: monospace; color: var(--theme_g1);">${escapeHtml(String(lsn))}</td>
+ <td style="text-align: right;">${queueCell}</td>
+ <td style="text-align: center;">${badge}</td>
+ <td style="color: var(--theme_bright);">${escapeHtml(fn)}</td>
+ <td style="text-align: right; font-size: 12px; white-space: nowrap; color: var(--theme_g1);">${formatTime(startDate)}</td>
+ <td style="text-align: right; font-size: 12px; white-space: nowrap; color: var(--theme_g1);">${formatTime(endDate)}</td>
+ <td style="text-align: right; font-size: 12px; white-space: nowrap;">${formatDuration(startDate, endDate)}</td>
+ <td style="font-family: monospace; font-size: 11px; color: var(--theme_g1);">${escapeHtml(workerId)}</td>
+ <td style="font-family: monospace; font-size: 11px; color: var(--theme_g1);">${escapeHtml(actionId)}</td>
+ `;
+ tbody.appendChild(tr);
+ }
+
+ container.style.display = 'block';
+ }
+
+ async function updateDashboard() {
+ try {
+ await Promise.all([
+ fetchHealth(),
+ fetchStats(),
+ fetchSysInfo(),
+ fetchWorkers(),
+ fetchQueues(),
+ fetchActionHistory()
+ ]);
+
+ clearError();
+ updateTimestamp();
+ } catch (error) {
+ console.error('Error updating dashboard:', error);
+ showError(error.message);
+ }
+ }
+
+ // Start updating
+ updateDashboard();
+ setInterval(updateDashboard, REFRESH_INTERVAL);
+ </script>
+</body>
+</html>
diff --git a/src/zenserver/frontend/html/compute/hub.html b/src/zenserver/frontend/html/compute/hub.html
new file mode 100644
index 000000000..32e1b05db
--- /dev/null
+++ b/src/zenserver/frontend/html/compute/hub.html
@@ -0,0 +1,170 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <link rel="stylesheet" type="text/css" href="../zen.css" />
+ <script src="../theme.js"></script>
+ <script src="../banner.js" defer></script>
+ <script src="../nav.js" defer></script>
+ <title>Zen Hub Dashboard</title>
+</head>
+<body>
+ <div class="container" style="max-width: 1400px; margin: 0 auto;">
+ <zen-banner cluster-status="nominal" subtitle="HUB" tagline="Overview" logo-src="../favicon.ico"></zen-banner>
+ <zen-nav>
+ <a href="/dashboard/">Home</a>
+ <a href="hub.html">Hub</a>
+ </zen-nav>
+ <div class="timestamp">Last updated: <span id="last-update">Never</span></div>
+
+ <div id="error-container"></div>
+
+ <div class="section-title">Capacity</div>
+ <div class="grid">
+ <div class="card">
+ <div class="card-title">Active Modules</div>
+ <div class="metric-value" id="instance-count">-</div>
+ <div class="metric-label">Currently provisioned</div>
+ </div>
+ <div class="card">
+ <div class="card-title">Peak Modules</div>
+ <div class="metric-value" id="max-instance-count">-</div>
+ <div class="metric-label">High watermark</div>
+ </div>
+ <div class="card">
+ <div class="card-title">Instance Limit</div>
+ <div class="metric-value" id="instance-limit">-</div>
+ <div class="metric-label">Maximum allowed</div>
+ <div class="progress-bar">
+ <div class="progress-fill" id="capacity-progress" style="width: 0%"></div>
+ </div>
+ </div>
+ </div>
+
+ <div class="section-title">Modules</div>
+ <div class="card">
+ <div class="card-title">Storage Server Instances</div>
+ <div id="empty-state" class="empty-state">No modules provisioned.</div>
+ <table id="module-table" style="display: none;">
+ <thead>
+ <tr>
+ <th>Module ID</th>
+ <th style="text-align: center;">Status</th>
+ </tr>
+ </thead>
+ <tbody id="module-table-body"></tbody>
+ </table>
+ </div>
+ </div>
+
+ <script>
+ // Endpoint base and poll interval (ms) for the hub dashboard.
+ const BASE_URL = window.location.origin;
+ const REFRESH_INTERVAL = 2000;
+
+ function escapeHtml(text) {
+ var div = document.createElement('div');
+ div.textContent = text;
+ return div.innerHTML;
+ }
+
+ // Show an error banner; the message is HTML-escaped before insertion.
+ function showError(message) {
+ document.getElementById('error-container').innerHTML =
+ '<div class="error">Error: ' + escapeHtml(message) + '</div>';
+ }
+
+ // Remove any previously shown error banner.
+ function clearError() {
+ document.getElementById('error-container').innerHTML = '';
+ }
+
+ async function fetchJSON(endpoint) {
+ var response = await fetch(BASE_URL + endpoint, {
+ headers: { 'Accept': 'application/json' }
+ });
+ if (!response.ok) {
+ throw new Error('HTTP ' + response.status + ': ' + response.statusText);
+ }
+ return await response.json();
+ }
+
+ async function fetchStats() {
+ var data = await fetchJSON('/hub/stats');
+
+ var current = data.currentInstanceCount || 0;
+ var max = data.maxInstanceCount || 0;
+ var limit = data.instanceLimit || 0;
+
+ document.getElementById('instance-count').textContent = current;
+ document.getElementById('max-instance-count').textContent = max;
+ document.getElementById('instance-limit').textContent = limit;
+
+ var pct = limit > 0 ? (current / limit) * 100 : 0;
+ document.getElementById('capacity-progress').style.width = pct + '%';
+
+ var banner = document.querySelector('zen-banner');
+ if (current === 0) {
+ banner.setAttribute('cluster-status', 'nominal');
+ } else if (limit > 0 && current >= limit * 0.9) {
+ banner.setAttribute('cluster-status', 'degraded');
+ } else {
+ banner.setAttribute('cluster-status', 'nominal');
+ }
+ }
+
+ async function fetchModules() {
+ var data = await fetchJSON('/hub/status');
+ var modules = data.modules || [];
+
+ var emptyState = document.getElementById('empty-state');
+ var table = document.getElementById('module-table');
+ var tbody = document.getElementById('module-table-body');
+
+ if (modules.length === 0) {
+ emptyState.style.display = '';
+ table.style.display = 'none';
+ return;
+ }
+
+ emptyState.style.display = 'none';
+ table.style.display = '';
+
+ tbody.innerHTML = '';
+ for (var i = 0; i < modules.length; i++) {
+ var m = modules[i];
+ var moduleId = m.moduleId || '';
+ var provisioned = m.provisioned;
+
+ var badge = provisioned
+ ? '<span class="status-badge active">Provisioned</span>'
+ : '<span class="status-badge inactive">Inactive</span>';
+
+ var tr = document.createElement('tr');
+ tr.innerHTML =
+ '<td style="font-family: monospace; font-size: 12px;">' + escapeHtml(moduleId) + '</td>' +
+ '<td style="text-align: center;">' + badge + '</td>';
+ tbody.appendChild(tr);
+ }
+ }
+
+ async function updateDashboard() {
+ var banner = document.querySelector('zen-banner');
+ try {
+ await Promise.all([
+ fetchStats(),
+ fetchModules()
+ ]);
+
+ clearError();
+ document.getElementById('last-update').textContent = new Date().toLocaleTimeString();
+ } catch (error) {
+ console.error('Error updating dashboard:', error);
+ showError(error.message);
+ banner.setAttribute('cluster-status', 'offline');
+ }
+ }
+
+ updateDashboard();
+ setInterval(updateDashboard, REFRESH_INTERVAL);
+ </script>
+</body>
+</html>
diff --git a/src/zenserver/frontend/html/compute/index.html b/src/zenserver/frontend/html/compute/index.html
new file mode 100644
index 000000000..9597fd7f3
--- /dev/null
+++ b/src/zenserver/frontend/html/compute/index.html
@@ -0,0 +1 @@
+<meta http-equiv="refresh" content="0; url=compute.html" /> \ No newline at end of file
diff --git a/src/zenserver/frontend/html/compute/orchestrator.html b/src/zenserver/frontend/html/compute/orchestrator.html
new file mode 100644
index 000000000..a519dee18
--- /dev/null
+++ b/src/zenserver/frontend/html/compute/orchestrator.html
@@ -0,0 +1,674 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <link rel="stylesheet" type="text/css" href="../zen.css" />
+ <script src="../theme.js"></script>
+ <script src="../banner.js" defer></script>
+ <script src="../nav.js" defer></script>
+ <title>Zen Orchestrator Dashboard</title>
+ <style>
+ .agent-count {
+ display: flex;
+ align-items: center;
+ gap: 8px;
+ font-size: 14px;
+ padding: 8px 16px;
+ border-radius: 6px;
+ background: var(--theme_g3);
+ border: 1px solid var(--theme_g2);
+ }
+
+ .agent-count .count {
+ font-size: 20px;
+ font-weight: 600;
+ color: var(--theme_bright);
+ }
+ </style>
+</head>
+<body>
+ <div class="container" style="max-width: 1400px; margin: 0 auto;">
+ <zen-banner cluster-status="nominal" load="0" logo-src="../favicon.ico"></zen-banner>
+ <zen-nav>
+ <a href="/dashboard/">Home</a>
+ <a href="compute.html">Node</a>
+ <a href="orchestrator.html">Orchestrator</a>
+ </zen-nav>
+ <div class="header">
+ <div>
+ <div class="timestamp">Last updated: <span id="last-update">Never</span></div>
+ </div>
+ <div class="agent-count">
+ <span>Agents:</span>
+ <span class="count" id="agent-count">-</span>
+ </div>
+ </div>
+
+ <div id="error-container"></div>
+
+ <div class="card">
+ <div class="card-title">Compute Agents</div>
+ <div id="empty-state" class="empty-state">No agents registered.</div>
+ <table id="agent-table" style="display: none;">
+ <thead>
+ <tr>
+ <th style="width: 40px; text-align: center;">Health</th>
+ <th>Hostname</th>
+ <th style="text-align: right;">CPUs</th>
+ <th style="text-align: right;">CPU Usage</th>
+ <th style="text-align: right;">Memory</th>
+ <th style="text-align: right;">Queues</th>
+ <th style="text-align: right;">Pending</th>
+ <th style="text-align: right;">Running</th>
+ <th style="text-align: right;">Completed</th>
+ <th style="text-align: right;">Traffic</th>
+ <th style="text-align: right;">Last Seen</th>
+ </tr>
+ </thead>
+ <tbody id="agent-table-body"></tbody>
+ </table>
+ </div>
+ <div class="card" style="margin-top: 20px;">
+ <div class="card-title">Connected Clients</div>
+ <div id="clients-empty" class="empty-state">No clients connected.</div>
+ <table id="clients-table" style="display: none;">
+ <thead>
+ <tr>
+ <th style="width: 40px; text-align: center;">Health</th>
+ <th>Client ID</th>
+ <th>Hostname</th>
+ <th>Address</th>
+ <th style="text-align: right;">Last Seen</th>
+ </tr>
+ </thead>
+ <tbody id="clients-table-body"></tbody>
+ </table>
+ </div>
+ <div class="card" style="margin-top: 20px;">
+ <div style="display: flex; align-items: center; gap: 12px; margin-bottom: 12px;">
+ <div class="card-title" style="margin-bottom: 0;">Event History</div>
+ <div class="history-tabs">
+ <button class="history-tab active" data-tab="workers" onclick="switchHistoryTab('workers')">Workers</button>
+ <button class="history-tab" data-tab="clients" onclick="switchHistoryTab('clients')">Clients</button>
+ </div>
+ </div>
+ <div id="history-panel-workers">
+ <div id="history-empty" class="empty-state">No provisioning events recorded.</div>
+ <table id="history-table" style="display: none;">
+ <thead>
+ <tr>
+ <th>Time</th>
+ <th>Event</th>
+ <th>Worker</th>
+ <th>Hostname</th>
+ </tr>
+ </thead>
+ <tbody id="history-table-body"></tbody>
+ </table>
+ </div>
+ <div id="history-panel-clients" style="display: none;">
+ <div id="client-history-empty" class="empty-state">No client events recorded.</div>
+ <table id="client-history-table" style="display: none;">
+ <thead>
+ <tr>
+ <th>Time</th>
+ <th>Event</th>
+ <th>Client</th>
+ <th>Hostname</th>
+ </tr>
+ </thead>
+ <tbody id="client-history-table-body"></tbody>
+ </table>
+ </div>
+ </div>
+ </div>
+
+ <script>
+ const BASE_URL = window.location.origin;
+ const REFRESH_INTERVAL = 2000;
+
+ function escapeHtml(text) {
+ var div = document.createElement('div');
+ div.textContent = text;
+ return div.innerHTML;
+ }
+
+ function showError(message) {
+ document.getElementById('error-container').innerHTML =
+ '<div class="error">Error: ' + escapeHtml(message) + '</div>';
+ }
+
+ function clearError() {
+ document.getElementById('error-container').innerHTML = '';
+ }
+
+ function formatLastSeen(dtMs) {
+ if (dtMs == null) return '-';
+ var seconds = Math.floor(dtMs / 1000);
+ if (seconds < 60) return seconds + 's ago';
+ var minutes = Math.floor(seconds / 60);
+ if (minutes < 60) return minutes + 'm ' + (seconds % 60) + 's ago';
+ var hours = Math.floor(minutes / 60);
+ return hours + 'h ' + (minutes % 60) + 'm ago';
+ }
+
+ function healthClass(dtMs, reachable) {
+ if (reachable === false) return 'health-red';
+ if (dtMs == null) return 'health-red';
+ var seconds = dtMs / 1000;
+ if (seconds < 30 && reachable === true) return 'health-green';
+ if (seconds < 120) return 'health-yellow';
+ return 'health-red';
+ }
+
+ function healthTitle(dtMs, reachable) {
+ var seenStr = dtMs != null ? 'Last seen ' + formatLastSeen(dtMs) : 'Never seen';
+ if (reachable === true) return seenStr + ' · Reachable';
+ if (reachable === false) return seenStr + ' · Unreachable';
+ return seenStr + ' · Reachability unknown';
+ }
+
+ function formatCpuUsage(percent) {
+ if (percent == null || percent === 0) return '-';
+ return percent.toFixed(1) + '%';
+ }
+
+ function formatMemory(usedBytes, totalBytes) {
+ if (!totalBytes) return '-';
+ var usedGiB = usedBytes / (1024 * 1024 * 1024);
+ var totalGiB = totalBytes / (1024 * 1024 * 1024);
+ return usedGiB.toFixed(1) + ' / ' + totalGiB.toFixed(1) + ' GiB';
+ }
+
+ function formatBytes(bytes) {
+ if (!bytes) return '-';
+ if (bytes < 1024) return bytes + ' B';
+ if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KiB';
+ if (bytes < 1024 * 1024 * 1024) return (bytes / (1024 * 1024)).toFixed(1) + ' MiB';
+ if (bytes < 1024 * 1024 * 1024 * 1024) return (bytes / (1024 * 1024 * 1024)).toFixed(1) + ' GiB';
+ return (bytes / (1024 * 1024 * 1024 * 1024)).toFixed(1) + ' TiB';
+ }
+
+ function formatTraffic(recv, sent) {
+ if (!recv && !sent) return '-';
+ return formatBytes(recv) + ' / ' + formatBytes(sent);
+ }
+
+ function parseIpFromUri(uri) {
+ try {
+ var url = new URL(uri);
+ var host = url.hostname;
+ // Strip IPv6 brackets
+ if (host.startsWith('[') && host.endsWith(']')) host = host.slice(1, -1);
+ // Only handle IPv4
+ var parts = host.split('.');
+ if (parts.length !== 4) return null;
+ var octets = parts.map(Number);
+ if (octets.some(function(o) { return isNaN(o) || o < 0 || o > 255; })) return null;
+ return octets;
+ } catch (e) {
+ return null;
+ }
+ }
+
+ function computeCidr(ips) {
+ if (ips.length === 0) return null;
+ if (ips.length === 1) return ips[0].join('.') + '/32';
+
+ // Convert each IP to a 32-bit integer
+ var ints = ips.map(function(o) {
+ return ((o[0] << 24) | (o[1] << 16) | (o[2] << 8) | o[3]) >>> 0;
+ });
+
+ // Find common prefix length by ANDing all identical high bits
+ var common = ~0 >>> 0;
+ for (var i = 1; i < ints.length; i++) {
+ // XOR to find differing bits, then mask away everything from the first difference down
+ var diff = (ints[0] ^ ints[i]) >>> 0;
+ if (diff !== 0) {
+ var bit = 31 - Math.floor(Math.log2(diff));
+ var mask = bit > 0 ? ((~0 << (32 - bit)) >>> 0) : 0;
+ common = (common & mask) >>> 0;
+ }
+ }
+
+ // Count leading ones in the common mask
+ var prefix = 0;
+ for (var b = 31; b >= 0; b--) {
+ if ((common >>> b) & 1) prefix++;
+ else break;
+ }
+
+ // Network address
+ var net = (ints[0] & common) >>> 0;
+ var a = (net >>> 24) & 0xff;
+ var bv = (net >>> 16) & 0xff;
+ var c = (net >>> 8) & 0xff;
+ var d = net & 0xff;
+ return a + '.' + bv + '.' + c + '.' + d + '/' + prefix;
+ }
+
        // Render the whole orchestrator dashboard from one status payload.
        // `data` is the JSON body of GET /orch/agents (also pushed over the
        // /orch/ws WebSocket): { hostname, workers[], events?, clients?,
        // client_events? } — optional arrays are rendered only when present.
        function renderDashboard(data) {
            var banner = document.querySelector('zen-banner');
            if (data.hostname) {
                banner.setAttribute('tagline', 'Orchestrator \u2014 ' + data.hostname);
            }
            var workers = data.workers || [];

            document.getElementById('agent-count').textContent = workers.length;

            // An empty cluster is shown as "degraded" rather than offline —
            // the orchestrator itself is still responding.
            if (workers.length === 0) {
                banner.setAttribute('cluster-status', 'degraded');
                banner.setAttribute('load', '0');
            } else {
                banner.setAttribute('cluster-status', 'nominal');
            }

            var emptyState = document.getElementById('empty-state');
            var table = document.getElementById('agent-table');
            var tbody = document.getElementById('agent-table-body');

            if (workers.length === 0) {
                emptyState.style.display = '';
                table.style.display = 'none';
            } else {
                emptyState.style.display = 'none';
                table.style.display = '';

                // Rebuild the agent table from scratch, accumulating
                // cluster-wide totals for the summary row as we go.
                tbody.innerHTML = '';
                var totalCpus = 0;
                var totalWeightedCpuUsage = 0;
                var totalMemUsed = 0;
                var totalMemTotal = 0;
                var totalQueues = 0;
                var totalPending = 0;
                var totalRunning = 0;
                var totalCompleted = 0;
                var totalBytesRecv = 0;
                var totalBytesSent = 0;
                var allIps = [];
                for (var i = 0; i < workers.length; i++) {
                    var w = workers[i];
                    var uri = w.uri || '';
                    var dt = w.dt;
                    // Each worker serves its own node dashboard under this path.
                    var dashboardUrl = uri + '/dashboard/compute/';

                    var id = w.id || '';

                    var hostname = w.hostname || '';
                    var cpus = w.cpus || 0;
                    totalCpus += cpus;
                    // CPU usage is weighted by core count so the banner load
                    // reflects the whole cluster, not a per-node average.
                    if (cpus > 0 && typeof w.cpu_usage === 'number') {
                        totalWeightedCpuUsage += w.cpu_usage * cpus;
                    }

                    var memTotal = w.memory_total || 0;
                    var memUsed = w.memory_used || 0;
                    totalMemTotal += memTotal;
                    totalMemUsed += memUsed;

                    var activeQueues = w.active_queues || 0;
                    totalQueues += activeQueues;

                    var actionsPending = w.actions_pending || 0;
                    var actionsRunning = w.actions_running || 0;
                    var actionsCompleted = w.actions_completed || 0;
                    totalPending += actionsPending;
                    totalRunning += actionsRunning;
                    totalCompleted += actionsCompleted;

                    var bytesRecv = w.bytes_received || 0;
                    var bytesSent = w.bytes_sent || 0;
                    totalBytesRecv += bytesRecv;
                    totalBytesSent += bytesSent;

                    // Collected to summarize the fleet as a CIDR in the total row.
                    var ip = parseIpFromUri(uri);
                    if (ip) allIps.push(ip);

                    var reachable = w.reachable;
                    var hClass = healthClass(dt, reachable);
                    var hTitle = healthTitle(dt, reachable);

                    // Small colored pills after the hostname for the worker's
                    // platform and provisioner, when reported.
                    var platform = w.platform || '';
                    var badges = '';
                    if (platform) {
                        var platColors = { windows: '#0078d4', wine: '#722f37', linux: '#e95420', macos: '#a2aaad' };
                        var platColor = platColors[platform] || '#8b949e';
                        badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + platColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(platform) + '</span>';
                    }
                    var provisioner = w.provisioner || '';
                    if (provisioner) {
                        var provColors = { horde: '#8957e5', nomad: '#3fb950' };
                        var provColor = provColors[provisioner] || '#8b949e';
                        badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + provColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(provisioner) + '</span>';
                    }

                    var tr = document.createElement('tr');
                    tr.title = id;
                    tr.innerHTML =
                        '<td style="text-align: center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' +
                        '<td><a href="' + escapeHtml(dashboardUrl) + '" target="_blank">' + escapeHtml(hostname) + '</a>' + badges + '</td>' +
                        '<td style="text-align: right;">' + (cpus > 0 ? cpus : '-') + '</td>' +
                        '<td style="text-align: right;">' + formatCpuUsage(w.cpu_usage) + '</td>' +
                        '<td style="text-align: right;">' + formatMemory(memUsed, memTotal) + '</td>' +
                        '<td style="text-align: right;">' + (activeQueues > 0 ? activeQueues : '-') + '</td>' +
                        '<td style="text-align: right;">' + actionsPending + '</td>' +
                        '<td style="text-align: right;">' + actionsRunning + '</td>' +
                        '<td style="text-align: right;">' + actionsCompleted + '</td>' +
                        '<td style="text-align: right; font-size: 11px; color: var(--theme_g1);">' + formatTraffic(bytesRecv, bytesSent) + '</td>' +
                        '<td style="text-align: right; color: var(--theme_g1);">' + formatLastSeen(dt) + '</td>';
                    tbody.appendChild(tr);
                }

                // Core-weighted average usage across all workers.
                var clusterLoad = totalCpus > 0 ? (totalWeightedCpuUsage / totalCpus) : 0;
                banner.setAttribute('load', clusterLoad.toFixed(1));

                // Total row
                var cidr = computeCidr(allIps);
                var totalTr = document.createElement('tr');
                totalTr.className = 'total-row';
                totalTr.innerHTML =
                    '<td></td>' +
                    '<td style="text-align: right; color: var(--theme_g1); text-transform: uppercase; font-size: 11px;">Total' + (cidr ? ' <span style="font-family: monospace; font-weight: normal;">' + escapeHtml(cidr) + '</span>' : '') + '</td>' +
                    '<td style="text-align: right;">' + totalCpus + '</td>' +
                    '<td></td>' +
                    '<td style="text-align: right;">' + formatMemory(totalMemUsed, totalMemTotal) + '</td>' +
                    '<td style="text-align: right;">' + totalQueues + '</td>' +
                    '<td style="text-align: right;">' + totalPending + '</td>' +
                    '<td style="text-align: right;">' + totalRunning + '</td>' +
                    '<td style="text-align: right;">' + totalCompleted + '</td>' +
                    '<td style="text-align: right; font-size: 11px;">' + formatTraffic(totalBytesRecv, totalBytesSent) + '</td>' +
                    '<td></td>';
                tbody.appendChild(totalTr);
            }

            clearError();
            document.getElementById('last-update').textContent = new Date().toLocaleTimeString();

            // Render provisioning history if present in WebSocket payload
            if (data.events) {
                renderProvisioningHistory(data.events);
            }

            // Render connected clients if present
            if (data.clients) {
                renderClients(data.clients);
            }

            // Render client history if present
            if (data.client_events) {
                renderClientHistory(data.client_events);
            }
        }
+
+ function eventBadge(type) {
+ var colors = { joined: 'var(--theme_ok)', left: 'var(--theme_fail)', returned: 'var(--theme_warn)' };
+ var labels = { joined: 'Joined', left: 'Left', returned: 'Returned' };
+ var color = colors[type] || 'var(--theme_g1)';
+ var label = labels[type] || type;
+ return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:var(--theme_g4);background:' + color + ';">' + escapeHtml(label) + '</span>';
+ }
+
+ function formatTimestamp(ts) {
+ if (!ts) return '-';
+ // CbObject DateTime serialized as ticks (100ns since 0001-01-01) or ISO string
+ var date;
+ if (typeof ts === 'number') {
+ // .NET-style ticks: convert to Unix ms
+ var unixMs = (ts - 621355968000000000) / 10000;
+ date = new Date(unixMs);
+ } else {
+ date = new Date(ts);
+ }
+ if (isNaN(date.getTime())) return '-';
+ return date.toLocaleTimeString();
+ }
+
+ var activeHistoryTab = 'workers';
+
+ function switchHistoryTab(tab) {
+ activeHistoryTab = tab;
+ var tabs = document.querySelectorAll('.history-tab');
+ for (var i = 0; i < tabs.length; i++) {
+ tabs[i].classList.toggle('active', tabs[i].getAttribute('data-tab') === tab);
+ }
+ document.getElementById('history-panel-workers').style.display = tab === 'workers' ? '' : 'none';
+ document.getElementById('history-panel-clients').style.display = tab === 'clients' ? '' : 'none';
+ }
+
+ function renderProvisioningHistory(events) {
+ var emptyState = document.getElementById('history-empty');
+ var table = document.getElementById('history-table');
+ var tbody = document.getElementById('history-table-body');
+
+ if (!events || events.length === 0) {
+ emptyState.style.display = '';
+ table.style.display = 'none';
+ return;
+ }
+
+ emptyState.style.display = 'none';
+ table.style.display = '';
+ tbody.innerHTML = '';
+
+ for (var i = 0; i < events.length; i++) {
+ var evt = events[i];
+ var tr = document.createElement('tr');
+ tr.innerHTML =
+ '<td style="color: var(--theme_g1);">' + formatTimestamp(evt.ts) + '</td>' +
+ '<td>' + eventBadge(evt.type) + '</td>' +
+ '<td>' + escapeHtml(evt.worker_id || '') + '</td>' +
+ '<td>' + escapeHtml(evt.hostname || '') + '</td>';
+ tbody.appendChild(tr);
+ }
+ }
+
+ function clientHealthClass(dtMs) {
+ if (dtMs == null) return 'health-red';
+ var seconds = dtMs / 1000;
+ if (seconds < 30) return 'health-green';
+ if (seconds < 120) return 'health-yellow';
+ return 'health-red';
+ }
+
+ function renderClients(clients) {
+ var emptyState = document.getElementById('clients-empty');
+ var table = document.getElementById('clients-table');
+ var tbody = document.getElementById('clients-table-body');
+
+ if (!clients || clients.length === 0) {
+ emptyState.style.display = '';
+ table.style.display = 'none';
+ return;
+ }
+
+ emptyState.style.display = 'none';
+ table.style.display = '';
+ tbody.innerHTML = '';
+
+ for (var i = 0; i < clients.length; i++) {
+ var c = clients[i];
+ var dt = c.dt;
+ var hClass = clientHealthClass(dt);
+ var hTitle = dt != null ? 'Last seen ' + formatLastSeen(dt) : 'Never seen';
+
+ var sessionBadge = '';
+ if (c.session_id) {
+ sessionBadge = ' <span style="font-family:monospace;font-size:10px;color:var(--theme_faint);" title="Session ' + escapeHtml(c.session_id) + '">' + escapeHtml(c.session_id.substring(0, 8)) + '</span>';
+ }
+
+ var tr = document.createElement('tr');
+ tr.innerHTML =
+ '<td style="text-align: center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' +
+ '<td>' + escapeHtml(c.id || '') + sessionBadge + '</td>' +
+ '<td>' + escapeHtml(c.hostname || '') + '</td>' +
+ '<td style="font-family: monospace; font-size: 12px; color: var(--theme_g1);">' + escapeHtml(c.address || '') + '</td>' +
+ '<td style="text-align: right; color: var(--theme_g1);">' + formatLastSeen(dt) + '</td>';
+ tbody.appendChild(tr);
+ }
+ }
+
+ function clientEventBadge(type) {
+ var colors = { connected: 'var(--theme_ok)', disconnected: 'var(--theme_fail)', updated: 'var(--theme_warn)' };
+ var labels = { connected: 'Connected', disconnected: 'Disconnected', updated: 'Updated' };
+ var color = colors[type] || 'var(--theme_g1)';
+ var label = labels[type] || type;
+ return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:var(--theme_g4);background:' + color + ';">' + escapeHtml(label) + '</span>';
+ }
+
+ function renderClientHistory(events) {
+ var emptyState = document.getElementById('client-history-empty');
+ var table = document.getElementById('client-history-table');
+ var tbody = document.getElementById('client-history-table-body');
+
+ if (!events || events.length === 0) {
+ emptyState.style.display = '';
+ table.style.display = 'none';
+ return;
+ }
+
+ emptyState.style.display = 'none';
+ table.style.display = '';
+ tbody.innerHTML = '';
+
+ for (var i = 0; i < events.length; i++) {
+ var evt = events[i];
+ var tr = document.createElement('tr');
+ tr.innerHTML =
+ '<td style="color: var(--theme_g1);">' + formatTimestamp(evt.ts) + '</td>' +
+ '<td>' + clientEventBadge(evt.type) + '</td>' +
+ '<td>' + escapeHtml(evt.client_id || '') + '</td>' +
+ '<td>' + escapeHtml(evt.hostname || '') + '</td>';
+ tbody.appendChild(tr);
+ }
+ }
+
        // Fetch-based polling fallback, used whenever the WebSocket is down.
        // Holds the setInterval handle while polling is active, else null.
        var pollTimer = null;
+
+ async function fetchProvisioningHistory() {
+ try {
+ var response = await fetch(BASE_URL + '/orch/history?limit=50', {
+ headers: { 'Accept': 'application/json' }
+ });
+ if (response.ok) {
+ var data = await response.json();
+ renderProvisioningHistory(data.events || []);
+ }
+ } catch (e) {
+ console.error('Error fetching provisioning history:', e);
+ }
+ }
+
+ async function fetchClients() {
+ try {
+ var response = await fetch(BASE_URL + '/orch/clients', {
+ headers: { 'Accept': 'application/json' }
+ });
+ if (response.ok) {
+ var data = await response.json();
+ renderClients(data.clients || []);
+ }
+ } catch (e) {
+ console.error('Error fetching clients:', e);
+ }
+ }
+
+ async function fetchClientHistory() {
+ try {
+ var response = await fetch(BASE_URL + '/orch/clients/history?limit=50', {
+ headers: { 'Accept': 'application/json' }
+ });
+ if (response.ok) {
+ var data = await response.json();
+ renderClientHistory(data.client_events || []);
+ }
+ } catch (e) {
+ console.error('Error fetching client history:', e);
+ }
+ }
+
+ async function fetchDashboard() {
+ var banner = document.querySelector('zen-banner');
+ try {
+ var response = await fetch(BASE_URL + '/orch/agents', {
+ headers: { 'Accept': 'application/json' }
+ });
+
+ if (!response.ok) {
+ banner.setAttribute('cluster-status', 'degraded');
+ throw new Error('HTTP ' + response.status + ': ' + response.statusText);
+ }
+
+ renderDashboard(await response.json());
+ fetchProvisioningHistory();
+ fetchClients();
+ fetchClientHistory();
+ } catch (error) {
+ console.error('Error updating dashboard:', error);
+ showError(error.message);
+ banner.setAttribute('cluster-status', 'offline');
+ }
+ }
+
+ function startPolling() {
+ if (pollTimer) return;
+ fetchDashboard();
+ pollTimer = setInterval(fetchDashboard, REFRESH_INTERVAL);
+ }
+
+ function stopPolling() {
+ if (pollTimer) {
+ clearInterval(pollTimer);
+ pollTimer = null;
+ }
+ }
+
        // WebSocket connection with automatic reconnect and polling fallback.
        // Holds the live socket while connected; null between attempts.
        var ws = null;
+
+ function connectWebSocket() {
+ var proto = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
+ ws = new WebSocket(proto + '//' + window.location.host + '/orch/ws');
+
+ ws.onopen = function() {
+ stopPolling();
+ clearError();
+ };
+
+ ws.onmessage = function(event) {
+ try {
+ renderDashboard(JSON.parse(event.data));
+ } catch (e) {
+ console.error('WebSocket message parse error:', e);
+ }
+ };
+
+ ws.onclose = function() {
+ ws = null;
+ startPolling();
+ setTimeout(connectWebSocket, 3000);
+ };
+
+ ws.onerror = function() {
+ // onclose will fire after onerror
+ };
+ }
+
        // Fetch orchestrator hostname for the banner tagline; best-effort,
        // failures are silently ignored.
        fetch(BASE_URL + '/orch/status', { headers: { 'Accept': 'application/json' } })
            .then(function(r) { return r.ok ? r.json() : null; })
            .then(function(d) {
                if (d && d.hostname) {
                    document.querySelector('zen-banner').setAttribute('tagline', 'Orchestrator \u2014 ' + d.hostname);
                }
            })
            .catch(function() {});

        // Initial load via fetch, then try the WebSocket; on socket failure
        // connectWebSocket's onclose handler resumes fetch polling.
        fetchDashboard();
        connectWebSocket();
+ </script>
+</body>
+</html>
diff --git a/src/UnrealEngine.ico b/src/zenserver/frontend/html/epicgames.ico
index 1cfa301a2..1cfa301a2 100644
--- a/src/UnrealEngine.ico
+++ b/src/zenserver/frontend/html/epicgames.ico
Binary files differ
diff --git a/src/zenserver/frontend/html/favicon.ico b/src/zenserver/frontend/html/favicon.ico
index 1cfa301a2..f7fb251b5 100644
--- a/src/zenserver/frontend/html/favicon.ico
+++ b/src/zenserver/frontend/html/favicon.ico
Binary files differ
diff --git a/src/zenserver/frontend/html/index.html b/src/zenserver/frontend/html/index.html
index 6a736e914..24a136a30 100644
--- a/src/zenserver/frontend/html/index.html
+++ b/src/zenserver/frontend/html/index.html
@@ -10,6 +10,9 @@
</script>
<link rel="shortcut icon" href="favicon.ico">
<link rel="stylesheet" type="text/css" href="zen.css" />
+ <script src="theme.js"></script>
+ <script src="banner.js" defer></script>
+ <script src="nav.js" defer></script>
<script type="module" src="zen.js"></script>
</head>
</html>
diff --git a/src/zenserver/frontend/html/nav.js b/src/zenserver/frontend/html/nav.js
new file mode 100644
index 000000000..a5de203f2
--- /dev/null
+++ b/src/zenserver/frontend/html/nav.js
@@ -0,0 +1,79 @@
+/**
+ * nav.js — Zen dashboard navigation bar Web Component
+ *
+ * Usage:
+ * <script src="nav.js" defer></script>
+ *
+ * <zen-nav>
+ * <a href="compute.html">Node</a>
+ * <a href="orchestrator.html">Orchestrator</a>
+ * </zen-nav>
+ *
+ * Each child <a> becomes a nav link. The current page is
+ * highlighted automatically based on the href.
+ */
+
// Custom element <zen-nav>: renders its light-DOM child <a> elements as a
// styled navigation bar inside a shadow root, highlighting the current page.
class ZenNav extends HTMLElement {

    // Runs whenever the element is (re)attached to the document; the shadow
    // root is created lazily so repeated attachment is safe.
    connectedCallback() {
        if (!this.shadowRoot) this.attachShadow({ mode: 'open' });
        this._render();
    }

    _render() {
        // Build one styled link per child <a>. The link whose href suffix
        // matches the current path is marked active (e.g. a path ending in
        // "orchestrator.html" activates href="orchestrator.html").
        const currentPath = window.location.pathname;
        const items = Array.from(this.querySelectorAll(':scope > a'));

        const links = items.map(a => {
            const href = a.getAttribute('href') || '';
            const label = a.textContent.trim();
            const active = currentPath.endsWith(href);
            // NOTE(review): href/label are interpolated unescaped into the
            // shadow HTML — assumes nav markup is author-controlled.
            return `<a class="nav-link${active ? ' active' : ''}" href="${href}">${label}</a>`;
        }).join('');

        this.shadowRoot.innerHTML = `
            <style>
            *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }

            :host {
                display: block;
                margin-bottom: 16px;
            }

            .nav-bar {
                display: flex;
                align-items: center;
                gap: 4px;
                padding: 4px;
                background: var(--theme_g3);
                border: 1px solid var(--theme_g2);
                border-radius: 6px;
            }

            .nav-link {
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
                font-size: 13px;
                font-weight: 500;
                color: var(--theme_g1);
                text-decoration: none;
                padding: 6px 14px;
                border-radius: 4px;
                transition: color 0.15s, background 0.15s;
            }

            .nav-link:hover {
                color: var(--theme_g0);
                background: var(--theme_p4);
            }

            .nav-link.active {
                color: var(--theme_bright);
                background: var(--theme_g2);
            }
            </style>
            <nav class="nav-bar">${links}</nav>
        `;
    }
}

customElements.define('zen-nav', ZenNav);
diff --git a/src/zenserver/frontend/html/pages/cache.js b/src/zenserver/frontend/html/pages/cache.js
new file mode 100644
index 000000000..3b838958a
--- /dev/null
+++ b/src/zenserver/frontend/html/pages/cache.js
@@ -0,0 +1,690 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+"use strict";
+
+import { ZenPage } from "./page.js"
+import { Fetcher } from "../util/fetcher.js"
+import { Friendly } from "../util/friendly.js"
+import { Modal } from "../util/modal.js"
+import { Table, Toolbar } from "../util/widgets.js"
+
+////////////////////////////////////////////////////////////////////////////////
+export class Page extends ZenPage
+{
    // Page entry point: builds the cache dashboard — a live stats section
    // (seeded by one fetch, then updated over WebSocket) and a namespaces
    // table whose rows fill in asynchronously as each namespace resolves.
    async main()
    {
        this.set_title("cache");

        // Cache Service Stats
        const stats_section = this._collapsible_section("Cache Service Stats");
        stats_section.tag().classify("dropall").text("raw yaml \u2192").on_click(() => {
            window.open("/stats/z$.yaml?cidstorestats=true&cachestorestats=true", "_blank");
        });
        this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles");
        this._details_host = stats_section;
        this._details_container = null;
        this._selected_category = null;

        // Seed the tiles once before the WebSocket starts streaming updates.
        const stats = await new Fetcher().resource("stats", "z$").json();
        if (stats)
        {
            this._render_stats(stats);
        }

        this._connect_stats_ws();

        // Cache Namespaces
        var section = this._collapsible_section("Cache Namespaces");

        section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all());

        var columns = [
            "namespace",
            "dir",
            "buckets",
            "entries",
            "size disk",
            "size mem",
            "actions",
        ];

        var zcache_info = await new Fetcher().resource("/z$/").json();
        this._cache_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_AlignNumeric);

        // One detail fetch per namespace; rows are appended as each resolves
        // (not awaited), so row order follows response arrival.
        for (const namespace of zcache_info["Namespaces"] || [])
        {
            new Fetcher().resource(`/z$/${namespace}/`).json().then((data) => {
                const row = this._cache_table.add_row(
                    "",
                    data["Configuration"]["RootDir"],
                    data["Buckets"].length,
                    data["EntryCount"],
                    Friendly.bytes(data["StorageSize"].DiskSize),
                    Friendly.bytes(data["StorageSize"].MemorySize)
                );
                // First cell: clickable namespace name opening the detail view.
                var cell = row.get_cell(0);
                cell.tag().text(namespace).on_click(() => this.view_namespace(namespace));

                // Last cell: per-namespace view/drop action buttons.
                cell = row.get_cell(-1);
                const action_tb = new Toolbar(cell, true);
                action_tb.left().add("view").on_click(() => this.view_namespace(namespace));
                action_tb.left().add("drop").on_click(() => this.drop_namespace(namespace));

                row.attr("zs_name", namespace);
            });
        }

        // Namespace detail area (inside namespaces section so it collapses together)
        this._namespace_host = section;
        this._namespace_container = null;
        this._selected_namespace = null;

        // Restore namespace from URL if present
        const ns_param = this.get_param("namespace");
        if (ns_param)
        {
            this.view_namespace(ns_param);
        }
    }
+
+ _collapsible_section(name)
+ {
+ const section = this.add_section(name);
+ const container = section._parent.inner();
+ const heading = container.firstElementChild;
+
+ heading.style.cursor = "pointer";
+ heading.style.userSelect = "none";
+
+ const indicator = document.createElement("span");
+ indicator.textContent = " \u25BC";
+ indicator.style.fontSize = "0.7em";
+ heading.appendChild(indicator);
+
+ let collapsed = false;
+ heading.addEventListener("click", (e) => {
+ if (e.target !== heading && e.target !== indicator)
+ {
+ return;
+ }
+ collapsed = !collapsed;
+ indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
+ let sibling = heading.nextElementSibling;
+ while (sibling)
+ {
+ sibling.style.display = collapsed ? "none" : "";
+ sibling = sibling.nextElementSibling;
+ }
+ });
+
+ return section;
+ }
+
+ _connect_stats_ws()
+ {
+ try
+ {
+ const proto = location.protocol === "https:" ? "wss:" : "ws:";
+ const ws = new WebSocket(`${proto}//${location.host}/stats`);
+
+ try { this._ws_paused = localStorage.getItem("zen-ws-paused") === "true"; } catch (e) { this._ws_paused = false; }
+ document.addEventListener("zen-ws-toggle", (e) => {
+ this._ws_paused = e.detail.paused;
+ });
+
+ ws.onmessage = (ev) => {
+ if (this._ws_paused)
+ {
+ return;
+ }
+ try
+ {
+ const all_stats = JSON.parse(ev.data);
+ const stats = all_stats["z$"];
+ if (stats)
+ {
+ this._render_stats(stats);
+ }
+ }
+ catch (e) { /* ignore parse errors */ }
+ };
+
+ ws.onclose = () => { this._stats_ws = null; };
+ ws.onerror = () => { ws.close(); };
+
+ this._stats_ws = ws;
+ }
+ catch (e) { /* WebSocket not available */ }
+ }
+
+ _render_stats(stats)
+ {
+ const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj);
+ const grid = this._stats_grid;
+
+ this._last_stats = stats;
+ grid.inner().innerHTML = "";
+
+ // Store I/O tile
+ {
+ const store = safe(stats, "cache.store");
+ if (store)
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile").classify("stats-tile-detailed");
+ if (this._selected_category === "store") tile.classify("stats-tile-selected");
+ tile.on_click(() => this._select_category("store"));
+ tile.tag().classify("card-title").text("Store I/O");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ const storeHits = store.hits || 0;
+ const storeMisses = store.misses || 0;
+ const storeTotal = storeHits + storeMisses;
+ const storeRatio = storeTotal > 0 ? ((storeHits / storeTotal) * 100).toFixed(1) + "%" : "-";
+ this._metric(left, storeRatio, "store hit ratio", true);
+ this._metric(left, Friendly.sep(storeHits), "hits");
+ this._metric(left, Friendly.sep(storeMisses), "misses");
+ this._metric(left, Friendly.sep(store.writes || 0), "writes");
+ this._metric(left, Friendly.sep(store.rejected_reads || 0), "rejected reads");
+ this._metric(left, Friendly.sep(store.rejected_writes || 0), "rejected writes");
+
+ const right = columns.tag().classify("tile-metrics");
+ const readRateMean = safe(store, "read.bytes.rate_mean") || 0;
+ const readRate1 = safe(store, "read.bytes.rate_1") || 0;
+ const readRate5 = safe(store, "read.bytes.rate_5") || 0;
+ const writeRateMean = safe(store, "write.bytes.rate_mean") || 0;
+ const writeRate1 = safe(store, "write.bytes.rate_1") || 0;
+ const writeRate5 = safe(store, "write.bytes.rate_5") || 0;
+ this._metric(right, Friendly.bytes(readRateMean) + "/s", "read rate (mean)", true);
+ this._metric(right, Friendly.bytes(readRate1) + "/s", "read rate (1m)");
+ this._metric(right, Friendly.bytes(readRate5) + "/s", "read rate (5m)");
+ this._metric(right, Friendly.bytes(writeRateMean) + "/s", "write rate (mean)");
+ this._metric(right, Friendly.bytes(writeRate1) + "/s", "write rate (1m)");
+ this._metric(right, Friendly.bytes(writeRate5) + "/s", "write rate (5m)");
+ }
+ }
+
+ // Hit/Miss tile
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Hit Ratio");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ const hits = safe(stats, "cache.hits") || 0;
+ const misses = safe(stats, "cache.misses") || 0;
+ const writes = safe(stats, "cache.writes") || 0;
+ const total = hits + misses;
+ const ratio = total > 0 ? ((hits / total) * 100).toFixed(1) + "%" : "-";
+
+ this._metric(left, ratio, "hit ratio", true);
+ this._metric(left, Friendly.sep(hits), "hits");
+ this._metric(left, Friendly.sep(misses), "misses");
+ this._metric(left, Friendly.sep(writes), "writes");
+
+ const right = columns.tag().classify("tile-metrics");
+ const cidHits = safe(stats, "cache.cidhits") || 0;
+ const cidMisses = safe(stats, "cache.cidmisses") || 0;
+ const cidWrites = safe(stats, "cache.cidwrites") || 0;
+ const cidTotal = cidHits + cidMisses;
+ const cidRatio = cidTotal > 0 ? ((cidHits / cidTotal) * 100).toFixed(1) + "%" : "-";
+
+ this._metric(right, cidRatio, "cid hit ratio", true);
+ this._metric(right, Friendly.sep(cidHits), "cid hits");
+ this._metric(right, Friendly.sep(cidMisses), "cid misses");
+ this._metric(right, Friendly.sep(cidWrites), "cid writes");
+ }
+
+ // HTTP Requests tile
+ {
+ const req = safe(stats, "requests");
+ if (req)
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("HTTP Requests");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ const reqData = req.requests || req;
+ this._metric(left, Friendly.sep(reqData.count || 0), "total requests", true);
+ if (reqData.rate_mean > 0)
+ {
+ this._metric(left, Friendly.sep(reqData.rate_mean, 1) + "/s", "req/sec (mean)");
+ }
+ if (reqData.rate_1 > 0)
+ {
+ this._metric(left, Friendly.sep(reqData.rate_1, 1) + "/s", "req/sec (1m)");
+ }
+ if (reqData.rate_5 > 0)
+ {
+ this._metric(left, Friendly.sep(reqData.rate_5, 1) + "/s", "req/sec (5m)");
+ }
+ if (reqData.rate_15 > 0)
+ {
+ this._metric(left, Friendly.sep(reqData.rate_15, 1) + "/s", "req/sec (15m)");
+ }
+ const badRequests = safe(stats, "cache.badrequestcount") || 0;
+ this._metric(left, Friendly.sep(badRequests), "bad requests");
+
+ const right = columns.tag().classify("tile-metrics");
+ this._metric(right, Friendly.duration(reqData.t_avg || 0), "avg latency", true);
+ if (reqData.t_p75)
+ {
+ this._metric(right, Friendly.duration(reqData.t_p75), "p75");
+ }
+ if (reqData.t_p95)
+ {
+ this._metric(right, Friendly.duration(reqData.t_p95), "p95");
+ }
+ if (reqData.t_p99)
+ {
+ this._metric(right, Friendly.duration(reqData.t_p99), "p99");
+ }
+ if (reqData.t_p999)
+ {
+ this._metric(right, Friendly.duration(reqData.t_p999), "p999");
+ }
+ if (reqData.t_max)
+ {
+ this._metric(right, Friendly.duration(reqData.t_max), "max");
+ }
+ }
+ }
+
+ // RPC tile
+ {
+ const rpc = safe(stats, "cache.rpc");
+ if (rpc)
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("RPC");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ this._metric(left, Friendly.sep(rpc.count || 0), "rpc calls", true);
+ this._metric(left, Friendly.sep(rpc.ops || 0), "batch ops");
+
+ const right = columns.tag().classify("tile-metrics");
+ if (rpc.records)
+ {
+ this._metric(right, Friendly.sep(rpc.records.count || 0), "record calls");
+ this._metric(right, Friendly.sep(rpc.records.ops || 0), "record ops");
+ }
+ if (rpc.values)
+ {
+ this._metric(right, Friendly.sep(rpc.values.count || 0), "value calls");
+ this._metric(right, Friendly.sep(rpc.values.ops || 0), "value ops");
+ }
+ if (rpc.chunks)
+ {
+ this._metric(right, Friendly.sep(rpc.chunks.count || 0), "chunk calls");
+ this._metric(right, Friendly.sep(rpc.chunks.ops || 0), "chunk ops");
+ }
+ }
+ }
+
+ // Storage tile
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile").classify("stats-tile-detailed");
+ if (this._selected_category === "storage") tile.classify("stats-tile-selected");
+ tile.on_click(() => this._select_category("storage"));
+ tile.tag().classify("card-title").text("Storage");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ this._metric(left, safe(stats, "cache.size.disk") != null ? Friendly.bytes(safe(stats, "cache.size.disk")) : "-", "cache disk", true);
+ this._metric(left, safe(stats, "cache.size.memory") != null ? Friendly.bytes(safe(stats, "cache.size.memory")) : "-", "cache memory");
+
+ const right = columns.tag().classify("tile-metrics");
+ this._metric(right, safe(stats, "cid.size.total") != null ? Friendly.bytes(safe(stats, "cid.size.total")) : "-", "cid total", true);
+ this._metric(right, safe(stats, "cid.size.tiny") != null ? Friendly.bytes(safe(stats, "cid.size.tiny")) : "-", "cid tiny");
+ this._metric(right, safe(stats, "cid.size.small") != null ? Friendly.bytes(safe(stats, "cid.size.small")) : "-", "cid small");
+ this._metric(right, safe(stats, "cid.size.large") != null ? Friendly.bytes(safe(stats, "cid.size.large")) : "-", "cid large");
+ }
+
+ // Upstream tile (only if upstream is active)
+ {
+ const upstream = safe(stats, "upstream");
+ if (upstream)
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Upstream");
+ const body = tile.tag().classify("tile-metrics");
+
+ const upstreamHits = safe(stats, "cache.upstream_hits") || 0;
+ this._metric(body, Friendly.sep(upstreamHits), "upstream hits", true);
+
+ if (upstream.url)
+ {
+ this._metric(body, upstream.url, "endpoint");
+ }
+ }
+ }
+ }
+
// Append one value/label metric pair to `parent`; `hero` renders it emphasized.
_metric(parent, value, label, hero = false)
{
    const metric = parent.tag().classify("tile-metric");
    if (hero)
        metric.classify("tile-metric-hero");

    const value_tag = metric.tag();
    value_tag.classify("metric-value").text(value);

    const label_tag = metric.tag();
    label_tag.classify("metric-label").text(label);
}
+
// Select (or toggle off) a drill-down category tile, then fetch a more
// detailed stats payload and render the matching detail table.
async _select_category(category)
{
    // Toggle off if already selected
    if (this._selected_category === category)
    {
        this._selected_category = null;
        this._clear_details();
        this._render_stats(this._last_stats);
        return;
    }

    this._selected_category = category;
    this._render_stats(this._last_stats);

    // Fetch detailed stats
    const detailed = await new Fetcher()
        .resource("stats", "z$")
        .param("cachestorestats", "true")
        .param("cidstorestats", "true")
        .json();

    // Stale-response guard: drop the result if the fetch failed or the user
    // switched/toggled the selection while the request was in flight.
    if (!detailed || this._selected_category !== category)
    {
        return;
    }

    this._clear_details();

    // Dotted-path getter that returns undefined (instead of throwing) when
    // an intermediate object is missing.
    const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj);

    if (category === "store")
    {
        this._render_store_details(detailed, safe);
    }
    else if (category === "storage")
    {
        this._render_storage_details(detailed, safe);
    }
}
+
+ _clear_details()
+ {
+ if (this._details_container)
+ {
+ this._details_container.inner().remove();
+ this._details_container = null;
+ }
+ }
+
// Render the per-namespace / per-bucket cache-store drill-down table.
// `safe` is the caller-supplied dotted-path getter.
_render_store_details(stats, safe)
{
    const namespaces = safe(stats, "cache.store.namespaces") || [];
    if (namespaces.length === 0)
    {
        return;
    }

    const container = this._details_host.tag();
    this._details_container = container;

    const columns = [
        "namespace",
        "bucket",
        "hits",
        "misses",
        "writes",
        "hit ratio",
        "read count",
        "read bandwidth",
        "write count",
        "write bandwidth",
    ];
    const table = new Table(container, columns, Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric);

    for (const ns of namespaces)
    {
        const nsHits = ns.hits || 0;
        const nsMisses = ns.misses || 0;
        const nsTotal = nsHits + nsMisses;
        // "-" when there has been no traffic at all (avoids 0/0).
        const nsRatio = nsTotal > 0 ? ((nsHits / nsTotal) * 100).toFixed(1) + "%" : "-";

        const readCount = safe(ns, "read.request.count") || 0;
        const readBytes = safe(ns, "read.bytes.count") || 0;
        const writeCount = safe(ns, "write.request.count") || 0;
        const writeBytes = safe(ns, "write.bytes.count") || 0;

        // Namespace summary row; the bucket column is left blank.
        table.add_row(
            ns.namespace,
            "",
            Friendly.sep(nsHits),
            Friendly.sep(nsMisses),
            Friendly.sep(ns.writes || 0),
            nsRatio,
            Friendly.sep(readCount),
            Friendly.bytes(readBytes),
            Friendly.sep(writeCount),
            Friendly.bytes(writeBytes),
        );

        if (ns.buckets && ns.buckets.length > 0)
        {
            // One row per bucket; the namespace is repeated so sorted rows
            // stay attributable to their namespace.
            for (const bucket of ns.buckets)
            {
                const bHits = bucket.hits || 0;
                const bMisses = bucket.misses || 0;
                const bTotal = bHits + bMisses;
                const bRatio = bTotal > 0 ? ((bHits / bTotal) * 100).toFixed(1) + "%" : "-";

                const bReadCount = safe(bucket, "read.request.count") || 0;
                const bReadBytes = safe(bucket, "read.bytes.count") || 0;
                const bWriteCount = safe(bucket, "write.request.count") || 0;
                const bWriteBytes = safe(bucket, "write.bytes.count") || 0;

                table.add_row(
                    ns.namespace,
                    bucket.bucket,
                    Friendly.sep(bHits),
                    Friendly.sep(bMisses),
                    Friendly.sep(bucket.writes || 0),
                    bRatio,
                    Friendly.sep(bReadCount),
                    Friendly.bytes(bReadBytes),
                    Friendly.sep(bWriteCount),
                    Friendly.bytes(bWriteBytes),
                );
            }
        }
    }
}
+
+ _render_storage_details(stats, safe)
+ {
+ const namespaces = safe(stats, "cache.store.namespaces") || [];
+ if (namespaces.length === 0)
+ {
+ return;
+ }
+
+ const container = this._details_host.tag();
+ this._details_container = container;
+
+ const columns = [
+ "namespace",
+ "bucket",
+ "disk",
+ "memory",
+ ];
+ const table = new Table(container, columns, Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric);
+
+ for (const ns of namespaces)
+ {
+ const diskSize = safe(ns, "size.disk") || 0;
+ const memSize = safe(ns, "size.memory") || 0;
+
+ table.add_row(
+ ns.namespace,
+ "",
+ Friendly.bytes(diskSize),
+ Friendly.bytes(memSize),
+ );
+
+ if (ns.buckets && ns.buckets.length > 0)
+ {
+ for (const bucket of ns.buckets)
+ {
+ const bDisk = safe(bucket, "size.disk") || 0;
+ const bMem = safe(bucket, "size.memory") || 0;
+
+ table.add_row(
+ ns.namespace,
+ bucket.bucket,
+ Friendly.bytes(bDisk),
+ Friendly.bytes(bMem),
+ );
+ }
+ }
+ }
+ }
+
// Show (or toggle off) the bucket breakdown for one cache namespace and
// persist the selection in the page URL so it survives reloads.
async view_namespace(namespace)
{
    // Toggle off if already selected
    if (this._selected_namespace === namespace)
    {
        this._selected_namespace = null;
        this._clear_namespace();
        this._clear_param("namespace");
        return;
    }

    this._selected_namespace = namespace;
    this._clear_namespace();
    this.set_param("namespace", namespace);

    const info = await new Fetcher().resource(`/z$/${namespace}/`).json();
    // Stale-response guard: selection changed while the fetch was in flight.
    if (this._selected_namespace !== namespace)
    {
        return;
    }

    const section = this._namespace_host.add_section(namespace);
    this._namespace_container = section;

    // Buckets table
    const bucket_section = section.add_section("Buckets");
    const bucket_columns = ["name", "disk", "memory", "entries", "actions"];
    const bucket_table = bucket_section.add_widget(
        Table,
        bucket_columns,
        Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric
    );

    // Right-align header for numeric columns (skip # and name)
    // NOTE(review): reaches into the Table widget's private `_element` and
    // assumes the first child is the header row — confirm against widgets.js.
    const header = bucket_table._element.firstElementChild;
    for (let i = 2; i < header.children.length - 1; i++)
    {
        header.children[i].style.textAlign = "right";
    }

    // Running totals, accumulated incrementally as each bucket fetch lands.
    let totalDisk = 0, totalMem = 0, totalEntries = 0;
    const total_row = bucket_table.add_row("TOTAL");
    total_row.get_cell(0).style("fontWeight", "bold");
    total_row.get_cell(1).style("textAlign", "right").style("fontWeight", "bold");
    total_row.get_cell(2).style("textAlign", "right").style("fontWeight", "bold");
    total_row.get_cell(3).style("textAlign", "right").style("fontWeight", "bold");

    for (const bucket of info["Buckets"])
    {
        const row = bucket_table.add_row(bucket);
        // Per-bucket sizes arrive asynchronously; each response fills its own
        // row and refreshes the TOTAL row with the sums so far.
        new Fetcher().resource(`/z$/${namespace}/${bucket}`).json().then((data) => {
            row.get_cell(1).text(Friendly.bytes(data["StorageSize"]["DiskSize"])).style("textAlign", "right");
            row.get_cell(2).text(Friendly.bytes(data["StorageSize"]["MemorySize"])).style("textAlign", "right");
            row.get_cell(3).text(Friendly.sep(data["DiskEntryCount"])).style("textAlign", "right");

            // Last cell hosts the per-bucket "drop" action.
            const cell = row.get_cell(-1);
            const action_tb = new Toolbar(cell, true);
            action_tb.left().add("drop").on_click(() => this.drop_bucket(namespace, bucket));

            totalDisk += data["StorageSize"]["DiskSize"];
            totalMem += data["StorageSize"]["MemorySize"];
            totalEntries += data["DiskEntryCount"];
            total_row.get_cell(1).text(Friendly.bytes(totalDisk)).style("textAlign", "right").style("fontWeight", "bold");
            total_row.get_cell(2).text(Friendly.bytes(totalMem)).style("textAlign", "right").style("fontWeight", "bold");
            total_row.get_cell(3).text(Friendly.sep(totalEntries)).style("textAlign", "right").style("fontWeight", "bold");
        });
    }

}
+
// Remove `name` from the page parameters and from the visible URL
// (without adding a history entry).
_clear_param(name)
{
    this._params.delete(name);
    const page_url = new URL(window.location);
    page_url.searchParams.delete(name);
    history.replaceState(null, "", page_url);
}
+
+ _clear_namespace()
+ {
+ if (this._namespace_container)
+ {
+ this._namespace_container._parent.inner().remove();
+ this._namespace_container = null;
+ }
+ }
+
// Ask for confirmation, then delete one bucket and rebuild the namespace view.
drop_bucket(namespace, bucket)
{
    const confirm_drop = async () => {
        await new Fetcher().resource("z$", namespace, bucket).delete();
        // Refresh the namespace view: reset the selection so view_namespace
        // re-opens instead of toggling off.
        this._selected_namespace = null;
        this._clear_namespace();
        this.view_namespace(namespace);
    };

    const modal = new Modal();
    modal.title("Confirmation")
         .message(`Drop bucket '${bucket}'?`)
         .option("Yes", () => confirm_drop())
         .option("No");
}
+
// Ask for confirmation, then delete an entire cache namespace and reload.
drop_namespace(namespace)
{
    const confirm_drop = async () => {
        await new Fetcher().resource("z$", namespace).delete();
        this.reload();
    };

    const modal = new Modal();
    modal.title("Confirmation")
         .message(`Drop cache namespace '${namespace}'?`)
         .option("Yes", () => confirm_drop())
         .option("No");
}
+
+ async drop_all()
+ {
+ const drop = async () => {
+ for (const row of this._cache_table)
+ {
+ const namespace = row.attr("zs_name");
+ await new Fetcher().resource("z$", namespace).delete();
+ }
+ this.reload();
+ };
+
+ new Modal()
+ .title("Confirmation")
+ .message("Drop every cache namespace?")
+ .option("Yes", () => drop())
+ .option("No");
+ }
+}
diff --git a/src/zenserver/frontend/html/pages/compute.js b/src/zenserver/frontend/html/pages/compute.js
new file mode 100644
index 000000000..ab3d49c27
--- /dev/null
+++ b/src/zenserver/frontend/html/pages/compute.js
@@ -0,0 +1,693 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+"use strict";
+
+import { ZenPage } from "./page.js"
+import { Fetcher } from "../util/fetcher.js"
+import { Friendly } from "../util/friendly.js"
+import { Table } from "../util/widgets.js"
+
+const MAX_HISTORY_POINTS = 60;
+
// Windows FILETIME: 100ns ticks since 1601-01-01
const FILETIME_EPOCH_OFFSET_MS = 11644473600000n;

// Convert a FILETIME tick count (number, bigint, or numeric string) to a
// JavaScript Date. Missing/zero ticks yield null. BigInt math is used so
// tick counts above Number.MAX_SAFE_INTEGER convert without precision loss.
function filetimeToDate(ticks)
{
    if (!ticks)
    {
        return null;
    }
    const unix_ms = BigInt(ticks) / 10000n - FILETIME_EPOCH_OFFSET_MS;
    return new Date(Number(unix_ms));
}
+
// Render a Date as a locale HH:MM:SS wall-clock string; null/undefined -> "-".
function formatTime(date)
{
    if (!date)
    {
        return "-";
    }
    const opts = { hour: "2-digit", minute: "2-digit", second: "2-digit" };
    return date.toLocaleTimeString([], opts);
}
+
// Human-readable elapsed time between two Dates.
// "-" for missing endpoints or negative spans; "N ms" under a second,
// "N.NN s" under a minute, "Mm SSs" beyond that.
function formatDuration(startDate, endDate)
{
    if (!startDate || !endDate)
    {
        return "-";
    }

    const elapsed = endDate - startDate;
    if (elapsed < 0)
    {
        return "-";
    }

    if (elapsed < 1000)
    {
        return elapsed + " ms";
    }

    if (elapsed < 60000)
    {
        return (elapsed / 1000).toFixed(2) + " s";
    }

    const minutes = Math.floor(elapsed / 60000);
    const seconds = ((elapsed % 60000) / 1000).toFixed(0).padStart(2, "0");
    return `${minutes}m ${seconds}s`;
}
+
+////////////////////////////////////////////////////////////////////////////////
+export class Page extends ZenPage
+{
// Page entry point: builds the collapsible sections, starts the Chart.js
// load, performs one immediate data fetch, then polls every 2 s (skipped
// while the site-wide live-update toggle is paused).
async main()
{
    this.set_title("compute");

    // Rolling chart history, bounded to MAX_HISTORY_POINTS samples.
    this._history = { timestamps: [], pending: [], running: [], completed: [], cpu: [] };
    this._selected_worker = null;
    this._chart_js = null;
    this._queue_chart = null;
    this._cpu_chart = null;

    // localStorage may be unavailable (e.g. private browsing) — best effort.
    this._ws_paused = false;
    try { this._ws_paused = localStorage.getItem("zen-ws-paused") === "true"; } catch (e) {}
    document.addEventListener("zen-ws-toggle", (e) => {
        this._ws_paused = e.detail.paused;
    });

    // Action Queue section
    const queue_section = this._collapsible_section("Action Queue");
    this._queue_grid = queue_section.tag().classify("grid").classify("stats-tiles");
    this._chart_host = queue_section;

    // Performance Metrics section
    const perf_section = this._collapsible_section("Performance Metrics");
    this._perf_host = perf_section;
    this._perf_grid = null;

    // Workers section
    const workers_section = this._collapsible_section("Workers");
    this._workers_host = workers_section;
    this._workers_table = null;
    this._worker_detail_container = null;

    // Queues section
    const queues_section = this._collapsible_section("Queues");
    this._queues_host = queues_section;
    this._queues_table = null;

    // Action History section
    const history_section = this._collapsible_section("Recent Actions");
    this._history_host = history_section;
    this._history_table = null;

    // System Resources section
    const sys_section = this._collapsible_section("System Resources");
    this._sys_grid = sys_section.tag().classify("grid").classify("stats-tiles");

    // Load Chart.js dynamically (not awaited; charts appear when ready)
    this._load_chartjs();

    // Initial fetch
    await this._fetch_all();

    // Poll
    this._poll_timer = setInterval(() => {
        if (!this._ws_paused)
        {
            this._fetch_all();
        }
    }, 2000);
}
+
// Wrap a page section with a click-to-collapse heading. Collapsing hides
// every sibling element after the heading; a small arrow shows the state.
_collapsible_section(name)
{
    const section = this.add_section(name);
    const container = section._parent.inner();
    const heading = container.firstElementChild;

    heading.style.cursor = "pointer";
    heading.style.userSelect = "none";

    // "▼" (expanded) / "▶" (collapsed) indicator appended to the heading.
    const indicator = document.createElement("span");
    indicator.textContent = " \u25BC";
    indicator.style.fontSize = "0.7em";
    heading.appendChild(indicator);

    let collapsed = false;
    heading.addEventListener("click", (e) => {
        // Ignore clicks bubbling from interactive children of the heading.
        if (e.target !== heading && e.target !== indicator)
        {
            return;
        }
        collapsed = !collapsed;
        indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
        // Toggle visibility of everything after the heading in this section.
        let sibling = heading.nextElementSibling;
        while (sibling)
        {
            sibling.style.display = collapsed ? "none" : "";
            sibling = sibling.nextElementSibling;
        }
    });

    return section;
}
+
+ async _load_chartjs()
+ {
+ if (window.Chart)
+ {
+ this._chart_js = window.Chart;
+ this._init_charts();
+ return;
+ }
+
+ try
+ {
+ const script = document.createElement("script");
+ script.src = "https://cdn.jsdelivr.net/npm/[email protected]/dist/chart.umd.min.js";
+ script.onload = () => {
+ this._chart_js = window.Chart;
+ this._init_charts();
+ };
+ document.head.appendChild(script);
+ }
+ catch (e) { /* Chart.js not available */ }
+ }
+
// Create the Chart.js charts once the library is available: a filled line
// chart for queue history plus a minimal CPU sparkline (its canvas is
// attached to the CPU card later, in _render_sysinfo). No-op if Chart.js
// never loaded.
_init_charts()
{
    if (!this._chart_js)
    {
        return;
    }

    // Queue history chart
    {
        const card = this._chart_host.tag().classify("card");
        card.tag().classify("card-title").text("Action Queue History");
        const container = card.tag();
        container.style("position", "relative").style("height", "300px").style("marginTop", "20px");
        const canvas = document.createElement("canvas");
        container.inner().appendChild(canvas);

        // Three stacked series; datasets[0..2] are updated in
        // _update_queue_chart in this same order.
        this._queue_chart = new this._chart_js(canvas.getContext("2d"), {
            type: "line",
            data: {
                labels: [],
                datasets: [
                    { label: "Pending", data: [], borderColor: "#f0883e", backgroundColor: "rgba(240, 136, 62, 0.1)", tension: 0.4, fill: true },
                    { label: "Running", data: [], borderColor: "#58a6ff", backgroundColor: "rgba(88, 166, 255, 0.1)", tension: 0.4, fill: true },
                    { label: "Completed", data: [], borderColor: "#3fb950", backgroundColor: "rgba(63, 185, 80, 0.1)", tension: 0.4, fill: true },
                ]
            },
            options: {
                responsive: true,
                maintainAspectRatio: false,
                plugins: { legend: { display: true, labels: { color: "#8b949e" } } },
                scales: { x: { display: false }, y: { beginAtZero: true, ticks: { color: "#8b949e" }, grid: { color: "#21262d" } } }
            }
        });
    }

    // CPU sparkline (will be appended to CPU card later)
    this._cpu_canvas = document.createElement("canvas");
    // Axes/legend/tooltips disabled: pure sparkline, y fixed to 0-100%.
    this._cpu_chart = new this._chart_js(this._cpu_canvas.getContext("2d"), {
        type: "line",
        data: {
            labels: [],
            datasets: [{
                data: [],
                borderColor: "#58a6ff",
                backgroundColor: "rgba(88, 166, 255, 0.15)",
                borderWidth: 1.5,
                tension: 0.4,
                fill: true,
                pointRadius: 0
            }]
        },
        options: {
            responsive: true,
            maintainAspectRatio: false,
            animation: false,
            plugins: { legend: { display: false }, tooltip: { enabled: false } },
            scales: { x: { display: false }, y: { display: false, min: 0, max: 100 } }
        }
    });
}
+
// Fetch every compute endpoint in parallel and re-render each section for
// which data arrived. Per-endpoint .catch collapses failures to null, so
// one unavailable service never blocks (or rejects) the whole refresh.
async _fetch_all()
{
    try
    {
        const [stats, sysinfo, workers_data, queues_data, history_data] = await Promise.all([
            new Fetcher().resource("/stats/compute").json().catch(() => null),
            new Fetcher().resource("/compute/sysinfo").json().catch(() => null),
            new Fetcher().resource("/compute/workers").json().catch(() => null),
            new Fetcher().resource("/compute/queues").json().catch(() => null),
            new Fetcher().resource("/compute/jobs/history").param("limit", "50").json().catch(() => null),
        ]);

        if (stats)
        {
            this._render_queue_stats(stats);
            this._update_queue_chart(stats);
            this._render_perf(stats);
        }
        if (sysinfo)
        {
            this._render_sysinfo(sysinfo);
        }
        if (workers_data)
        {
            this._render_workers(workers_data);
        }
        if (queues_data)
        {
            this._render_queues(queues_data);
        }
        if (history_data)
        {
            this._render_action_history(history_data);
        }
    }
    catch (e) { /* service unavailable */ }
}
+
// Rebuild the three action-queue summary tiles (pending/running/completed)
// from a fresh stats payload.
_render_queue_stats(data)
{
    const grid = this._queue_grid;
    grid.inner().innerHTML = "";

    // One card per counter, with the count as the hero metric.
    const make_tile = (title, value, label) => {
        const tile = grid.tag().classify("card").classify("stats-tile");
        tile.tag().classify("card-title").text(title);
        const body = tile.tag().classify("tile-metrics");
        this._metric(body, Friendly.sep(value), label, true);
    };

    make_tile("Pending Actions", data.actions_pending || 0, "waiting to be scheduled");
    make_tile("Running Actions", data.actions_submitted || 0, "currently executing");
    make_tile("Completed Actions", data.actions_complete || 0, "results available");
}
+
+ _update_queue_chart(data)
+ {
+ const h = this._history;
+ h.timestamps.push(new Date().toLocaleTimeString());
+ h.pending.push(data.actions_pending || 0);
+ h.running.push(data.actions_submitted || 0);
+ h.completed.push(data.actions_complete || 0);
+
+ while (h.timestamps.length > MAX_HISTORY_POINTS)
+ {
+ h.timestamps.shift();
+ h.pending.shift();
+ h.running.shift();
+ h.completed.shift();
+ }
+
+ if (this._queue_chart)
+ {
+ this._queue_chart.data.labels = h.timestamps;
+ this._queue_chart.data.datasets[0].data = h.pending;
+ this._queue_chart.data.datasets[1].data = h.running;
+ this._queue_chart.data.datasets[2].data = h.completed;
+ this._queue_chart.update("none");
+ }
+ }
+
// Rebuild the performance-metrics grid: a single completion-rate card
// showing windowed rates on the left and totals on the right.
_render_perf(data)
{
    if (!this._perf_grid)
    {
        this._perf_grid = this._perf_host.tag().classify("grid").classify("stats-tiles");
    }
    const grid = this._perf_grid;
    grid.inner().innerHTML = "";

    const retired = data.actions_retired || {};

    // Completion rate card
    const tile = grid.tag().classify("card").classify("stats-tile");
    tile.tag().classify("card-title").text("Completion Rate");
    const body = tile.tag().classify("tile-columns");

    const rates = body.tag().classify("tile-metrics");
    this._metric(rates, this._fmt_rate(retired.rate_1), "1 min rate", true);
    this._metric(rates, this._fmt_rate(retired.rate_5), "5 min rate");
    this._metric(rates, this._fmt_rate(retired.rate_15), "15 min rate");

    const totals = body.tag().classify("tile-metrics");
    this._metric(totals, Friendly.sep(retired.count || 0), "total retired", true);
    this._metric(totals, this._fmt_rate(retired.rate_mean), "mean rate");
}
+
+ _fmt_rate(rate)
+ {
+ if (rate == null) return "-";
+ return rate.toFixed(2) + "/s";
+ }
+
// Render the workers table from the worker-ID list, then fetch each
// worker's descriptor to fill in the details. Also keeps a selected
// worker's detail panel in sync across refreshes.
_render_workers(data)
{
    const workerIds = data.workers || [];

    if (this._workers_table)
    {
        this._workers_table.clear();
    }
    else
    {
        this._workers_table = this._workers_host.add_widget(
            Table,
            ["name", "platform", "cores", "timeout", "functions", "worker ID"],
            Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1
        );
    }

    if (workerIds.length === 0)
    {
        return;
    }

    // Fetch each worker's descriptor
    Promise.all(
        workerIds.map(id =>
            new Fetcher().resource("/compute/workers", id).json()
                .then(desc => ({ id, desc }))
                .catch(() => ({ id, desc: null }))
        )
    ).then(results => {
        // Cleared again here because a newer poll cycle may have repopulated
        // the table while the descriptor fetches were in flight.
        this._workers_table.clear();
        for (const { id, desc } of results)
        {
            const name = desc ? (desc.name || "-") : "-";
            // NOTE(review): the "platform" column is populated from
            // desc.host — confirm that is the intended field.
            const host = desc ? (desc.host || "-") : "-";
            const cores = desc ? (desc.cores != null ? desc.cores : "-") : "-";
            const timeout = desc ? (desc.timeout != null ? desc.timeout + "s" : "-") : "-";
            const functions = desc ? (desc.functions ? desc.functions.length : 0) : "-";

            // Name cell is filled below so it can carry a click handler.
            const row = this._workers_table.add_row(
                "",
                host,
                String(cores),
                String(timeout),
                String(functions),
                id,
            );

            // Make name clickable to expand detail
            const cell = row.get_cell(0);
            cell.tag().text(name).on_click(() => this._toggle_worker_detail(id, desc));

            // Highlight selected
            if (id === this._selected_worker)
            {
                row.style("background", "var(--theme_p3)");
            }
        }

        // Cache descriptors by ID for re-rendering the detail panel.
        this._worker_descriptors = Object.fromEntries(results.map(r => [r.id, r.desc]));

        // Re-render detail if still selected
        if (this._selected_worker && this._worker_descriptors[this._selected_worker])
        {
            this._show_worker_detail(this._selected_worker, this._worker_descriptors[this._selected_worker]);
        }
        else if (this._selected_worker)
        {
            // Selected worker vanished from the fleet — drop the panel.
            this._selected_worker = null;
            this._clear_worker_detail();
        }
    });
}
+
// Toggle the per-worker detail panel: clicking the selected worker closes
// it, clicking another worker switches to that one.
_toggle_worker_detail(id, desc)
{
    const deselect = this._selected_worker === id;
    if (deselect)
    {
        this._selected_worker = null;
        this._clear_worker_detail();
    }
    else
    {
        this._selected_worker = id;
        this._show_worker_detail(id, desc);
    }
}
+
+ _clear_worker_detail()
+ {
+ if (this._worker_detail_container)
+ {
+ this._worker_detail_container._parent.inner().remove();
+ this._worker_detail_container = null;
+ }
+ }
+
// Build the expanded detail panel for one worker from its descriptor:
// basic properties, advertised functions, executables (with a size total),
// extra files, directories, and environment variables. Sections with no
// data are simply omitted.
_show_worker_detail(id, desc)
{
    this._clear_worker_detail();
    if (!desc)
    {
        return;
    }

    const section = this._workers_host.add_section(desc.name || id);
    this._worker_detail_container = section;

    // Basic info table
    const info_table = section.add_widget(
        Table, ["property", "value"], Table.Flag_FitLeft|Table.Flag_PackRight
    );
    const fields = [
        ["Worker ID", id],
        ["Path", desc.path || "-"],
        ["Platform", desc.host || "-"],
        ["Build System", desc.buildsystem_version || "-"],
        ["Cores", desc.cores != null ? String(desc.cores) : "-"],
        ["Timeout", desc.timeout != null ? desc.timeout + "s" : "-"],
    ];
    for (const [label, value] of fields)
    {
        info_table.add_row(label, value);
    }

    // Functions
    const functions = desc.functions || [];
    if (functions.length > 0)
    {
        const fn_section = section.add_section("Functions");
        const fn_table = fn_section.add_widget(
            Table, ["name", "version"], Table.Flag_FitLeft|Table.Flag_PackRight
        );
        for (const f of functions)
        {
            fn_table.add_row(f.name || "-", f.version || "-");
        }
    }

    // Executables
    const executables = desc.executables || [];
    if (executables.length > 0)
    {
        const exec_section = section.add_section("Executables");
        const exec_table = exec_section.add_widget(
            Table, ["path", "hash", "size"], Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_AlignNumeric
        );
        let totalSize = 0;
        for (const e of executables)
        {
            // NOTE(review): the column header says "path" but rows use
            // e.name — confirm which field the descriptor provides.
            exec_table.add_row(e.name || "-", e.hash || "-", e.size != null ? Friendly.bytes(e.size) : "-");
            totalSize += e.size || 0;
        }
        const total_row = exec_table.add_row("TOTAL", "", Friendly.bytes(totalSize));
        total_row.get_cell(0).style("fontWeight", "bold");
        total_row.get_cell(2).style("fontWeight", "bold");
    }

    // Files
    const files = desc.files || [];
    if (files.length > 0)
    {
        const files_section = section.add_section("Files");
        const files_table = files_section.add_widget(
            Table, ["name", "hash"], Table.Flag_FitLeft|Table.Flag_PackRight
        );
        for (const f of files)
        {
            // Entries may be plain path strings or {name, hash} objects.
            files_table.add_row(typeof f === "string" ? f : (f.name || "-"), typeof f === "string" ? "" : (f.hash || ""));
        }
    }

    // Directories
    const dirs = desc.dirs || [];
    if (dirs.length > 0)
    {
        const dirs_section = section.add_section("Directories");
        for (const d of dirs)
        {
            dirs_section.tag().classify("detail-tag").text(d);
        }
    }

    // Environment
    const env = desc.environment || [];
    if (env.length > 0)
    {
        const env_section = section.add_section("Environment");
        for (const e of env)
        {
            env_section.tag().classify("detail-tag").text(e);
        }
    }
}
+
// Rebuild the queues table: one row per queue with a derived status string
// (explicit cancelled/draining states win over the completion flag).
_render_queues(data)
{
    if (this._queues_table)
    {
        this._queues_table.clear();
    }
    else
    {
        this._queues_table = this._queues_host.add_widget(
            Table,
            ["ID", "status", "active", "completed", "failed", "abandoned", "cancelled", "token"],
            Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1
        );
    }

    for (const queue of data.queues || [])
    {
        let status = "active";
        if (queue.state === "cancelled")
        {
            status = "cancelled";
        }
        else if (queue.state === "draining")
        {
            status = "draining";
        }
        else if (queue.is_complete)
        {
            status = "complete";
        }

        this._queues_table.add_row(
            queue.queue_id != null ? String(queue.queue_id) : "-",
            status,
            String(queue.active_count ?? 0),
            String(queue.completed_count ?? 0),
            String(queue.failed_count ?? 0),
            String(queue.abandoned_count ?? 0),
            String(queue.cancelled_count ?? 0),
            queue.queue_token || "-",
        );
    }
}
+
// Rebuild the recent-actions table. Timestamps arrive as Windows FILETIME
// ticks (time_Running, and time_Completed or time_Failed) and are converted
// to local Dates for display.
_render_action_history(data)
{
    const entries = data.history || [];

    if (this._history_table)
    {
        this._history_table.clear();
    }
    else
    {
        this._history_table = this._history_host.add_widget(
            Table,
            ["LSN", "queue", "status", "function", "started", "finished", "duration", "worker ID", "action ID"],
            Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1
        );
    }

    // Entries arrive oldest-first; reverse to show newest at top
    for (const entry of [...entries].reverse())
    {
        const lsn = entry.lsn != null ? String(entry.lsn) : "-";
        const queueId = entry.queueId ? String(entry.queueId) : "-";
        // succeeded is tri-state: null/undefined means not yet known.
        const status = entry.succeeded == null ? "unknown"
            : entry.succeeded ? "ok" : "failed";
        const desc = entry.actionDescriptor || {};
        const fn = desc.Function || "-";
        // A failed action carries time_Failed instead of time_Completed.
        const startDate = filetimeToDate(entry.time_Running);
        const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed);

        this._history_table.add_row(
            lsn,
            queueId,
            status,
            fn,
            formatTime(startDate),
            formatTime(endDate),
            formatDuration(startDate, endDate),
            entry.workerId || "-",
            entry.actionId || "-",
        );
    }
}
+
// Rebuild the System Resources grid: CPU (with usage bar and sparkline),
// memory, and disk cards. The CPU sparkline canvas created in _init_charts
// is re-attached here on every refresh.
_render_sysinfo(data)
{
    const grid = this._sys_grid;
    grid.inner().innerHTML = "";

    // CPU card
    {
        const cpuUsage = data.cpu_usage || 0;
        const tile = grid.tag().classify("card").classify("stats-tile");
        tile.tag().classify("card-title").text("CPU Usage");
        const body = tile.tag().classify("tile-metrics");
        this._metric(body, cpuUsage.toFixed(1) + "%", "percent", true);

        // Progress bar
        const bar = body.tag().classify("progress-bar");
        bar.tag().classify("progress-fill").style("width", cpuUsage + "%");

        // CPU sparkline — history trimmed to the shared rolling window.
        this._history.cpu.push(cpuUsage);
        while (this._history.cpu.length > MAX_HISTORY_POINTS) this._history.cpu.shift();
        if (this._cpu_chart)
        {
            const sparkContainer = body.tag();
            sparkContainer.style("position", "relative").style("height", "60px").style("marginTop", "12px");
            sparkContainer.inner().appendChild(this._cpu_canvas);

            this._cpu_chart.data.labels = this._history.cpu.map(() => "");
            this._cpu_chart.data.datasets[0].data = this._history.cpu;
            this._cpu_chart.update("none");
        }

        // CPU details
        this._stat_row(body, "Packages", data.cpu_count != null ? String(data.cpu_count) : "-");
        this._stat_row(body, "Physical Cores", data.core_count != null ? String(data.core_count) : "-");
        this._stat_row(body, "Logical Processors", data.lp_count != null ? String(data.lp_count) : "-");
    }

    // Memory card
    {
        const memUsed = data.memory_used || 0;
        // Total defaults to 1 so the percentage math never divides by zero.
        const memTotal = data.memory_total || 1;
        const memPercent = (memUsed / memTotal) * 100;
        const tile = grid.tag().classify("card").classify("stats-tile");
        tile.tag().classify("card-title").text("Memory");
        const body = tile.tag().classify("tile-metrics");
        this._stat_row(body, "Used", Friendly.bytes(memUsed));
        this._stat_row(body, "Total", Friendly.bytes(memTotal));
        const bar = body.tag().classify("progress-bar");
        bar.tag().classify("progress-fill").style("width", memPercent + "%");
    }

    // Disk card
    {
        const diskUsed = data.disk_used || 0;
        // Same zero-division guard as the memory card.
        const diskTotal = data.disk_total || 1;
        const diskPercent = (diskUsed / diskTotal) * 100;
        const tile = grid.tag().classify("card").classify("stats-tile");
        tile.tag().classify("card-title").text("Disk");
        const body = tile.tag().classify("tile-metrics");
        this._stat_row(body, "Used", Friendly.bytes(diskUsed));
        this._stat_row(body, "Total", Friendly.bytes(diskTotal));
        const bar = body.tag().classify("progress-bar");
        bar.tag().classify("progress-fill").style("width", diskPercent + "%");
    }
}
+
// Append a single label/value line to a stats card body.
_stat_row(parent, label, value)
{
    const row = parent.tag().classify("stats-row");
    const label_tag = row.tag();
    label_tag.classify("stats-label").text(label);
    const value_tag = row.tag();
    value_tag.classify("stats-value").text(value);
}
+
// Append a value/label metric pair to `parent`; `hero` renders it emphasized.
_metric(parent, value, label, hero = false)
{
    const metric = parent.tag().classify("tile-metric");
    if (hero)
        metric.classify("tile-metric-hero");

    metric.tag().classify("metric-value").text(value);
    metric.tag().classify("metric-label").text(label);
}
+}
diff --git a/src/zenserver/frontend/html/pages/cookartifacts.js b/src/zenserver/frontend/html/pages/cookartifacts.js
new file mode 100644
index 000000000..f2ae094b9
--- /dev/null
+++ b/src/zenserver/frontend/html/pages/cookartifacts.js
@@ -0,0 +1,397 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+"use strict";
+
+import { ZenPage } from "./page.js"
+import { Fetcher } from "../util/fetcher.js"
+import { Table, Toolbar, PropTable } from "../util/widgets.js"
+
+////////////////////////////////////////////////////////////////////////////////
+export class Page extends ZenPage
+{
+ // Page entry point: kicks off the artifact fetch (and optional entry
+ // context fetch) then builds the page asynchronously.
+ main()
+ {
+ this.set_title("cook artifacts");
+
+ const project = this.get_param("project");
+ const oplog = this.get_param("oplog");
+ const opkey = this.get_param("opkey");
+ const artifact_hash = this.get_param("hash");
+
+ // Fetch the artifact content as JSON
+ this._artifact = new Fetcher()
+ .resource("prj", project, "oplog", oplog, artifact_hash + ".json")
+ .json();
+
+ // Optionally fetch entry info for display context
+ if (opkey)
+ {
+ this._entry = new Fetcher()
+ .resource("prj", project, "oplog", oplog, "entries")
+ .param("opkey", opkey)
+ .cbo();
+ }
+
+ // _build_page is async; fetches above are awaited inside it.
+ this._build_page();
+ }
+
+ // Map CookDependency enum values to display names.
+ // Returns "Unknown (<n>)" for values outside the known range.
+ _get_dependency_type_name(type_value)
+ {
+ const type_names = {
+ 0: "None",
+ 1: "File",
+ 2: "Function",
+ 3: "TransitiveBuild",
+ 4: "Package",
+ 5: "ConsoleVariable",
+ 6: "Config",
+ 7: "SettingsObject",
+ 8: "NativeClass",
+ 9: "AssetRegistryQuery",
+ 10: "RedirectionTarget"
+ };
+ return type_names[type_value] || `Unknown (${type_value})`;
+ }
+
+ // Check if Data content should be expandable: only non-trivial JSON
+ // (>= 40 chars, array/object) whose pretty-printed form spans multiple
+ // lines warrants an expand/collapse toggle.
+ _should_make_expandable(data_string)
+ {
+ if (!data_string || data_string.length < 40)
+ return false;
+
+ // Check if it's JSON array or object
+ if (!data_string.startsWith('[') && !data_string.startsWith('{'))
+ return false;
+
+ // Check if formatting would add newlines
+ try {
+ const parsed = JSON.parse(data_string);
+ const formatted = JSON.stringify(parsed, null, 2);
+ return formatted.includes('\n');
+ } catch (e) {
+ // Not valid JSON -> treat as plain text, not expandable.
+ return false;
+ }
+ }
+
+ // Get first line of content for collapsed state; truncates single-line
+ // strings longer than 80 chars and appends "..." to signal more content.
+ _get_first_line(data_string)
+ {
+ if (!data_string)
+ return "";
+
+ const newline_index = data_string.indexOf('\n');
+ if (newline_index === -1)
+ {
+ // No newline, truncate if too long
+ return data_string.length > 80 ? data_string.substring(0, 77) + "..." : data_string;
+ }
+ return data_string.substring(0, newline_index) + "...";
+ }
+
+ // Format JSON with 2-space indentation; returns the input unchanged
+ // if it does not parse as JSON.
+ _format_json(data_string)
+ {
+ try {
+ const parsed = JSON.parse(data_string);
+ return JSON.stringify(parsed, null, 2);
+ } catch (e) {
+ return data_string;
+ }
+ }
+
+ // Toggle expand/collapse state of a Data cell. The full payload is kept
+ // in the "data-full" attribute; the "expanded" attribute tracks state.
+ _toggle_data_cell(cell)
+ {
+ const is_expanded = cell.attr("expanded") !== null;
+ const full_data = cell.attr("data-full");
+
+ // Find the text wrapper span (second child; first child is the icon
+ // created in _build_dependency_section).
+ const text_wrapper = cell.first_child().next_sibling();
+
+ if (is_expanded)
+ {
+ // Collapse: show first line only
+ const first_line = this._get_first_line(full_data);
+ text_wrapper.text(first_line);
+ cell.attr("expanded", null);
+ }
+ else
+ {
+ // Expand: show formatted JSON
+ const formatted = this._format_json(full_data);
+ text_wrapper.text(formatted);
+ cell.attr("expanded", "");
+ }
+ }
+
+ // Format dependency data based on its structure. Input is a positional
+ // array [type, ...]; returns an object with Name/Value/Data/Hash fields
+ // inferred heuristically from element types and positions.
+ _format_dependency(dep_array)
+ {
+ const type = dep_array[0];
+ const formatted = {};
+
+ // Common patterns based on the example data:
+ // Type 2 (Function): [type, name, array, hash]
+ // Type 4 (Package): [type, path, hash]
+ // Type 5 (ConsoleVariable): [type, bool, array, hash]
+ // Type 8 (NativeClass): [type, path, hash]
+ // Type 9 (AssetRegistryQuery): [type, bool, object, hash]
+ // Type 10 (RedirectionTarget): [type, path, hash]
+
+ if (dep_array.length > 1)
+ {
+ // Most types have a name/path as second element
+ if (typeof dep_array[1] === "string")
+ {
+ formatted.Name = dep_array[1];
+ }
+ else if (typeof dep_array[1] === "boolean")
+ {
+ formatted.Value = dep_array[1].toString();
+ }
+ }
+
+ if (dep_array.length > 2)
+ {
+ // Third element varies: arrays/objects are serialized into Data,
+ // a string is assumed to be the hash.
+ if (Array.isArray(dep_array[2]))
+ {
+ formatted.Data = JSON.stringify(dep_array[2]);
+ }
+ else if (typeof dep_array[2] === "object")
+ {
+ formatted.Data = JSON.stringify(dep_array[2]);
+ }
+ else if (typeof dep_array[2] === "string")
+ {
+ formatted.Hash = dep_array[2];
+ }
+ }
+
+ if (dep_array.length > 3)
+ {
+ // Fourth element is usually the hash
+ if (typeof dep_array[3] === "string")
+ {
+ formatted.Hash = dep_array[3];
+ }
+ }
+
+ return formatted;
+ }
+
+ // Build the full page: artifact info table, dependency sections and
+ // runtime dependency table, using the fetches started in main().
+ async _build_page()
+ {
+ const project = this.get_param("project");
+ const oplog = this.get_param("oplog");
+ const opkey = this.get_param("opkey");
+ const artifact_hash = this.get_param("hash");
+
+ // Build page title
+ let title = "Cook Artifacts";
+ if (this._entry)
+ {
+ try
+ {
+ const entry = await this._entry;
+ const entry_obj = entry.as_object().find("entry").as_object();
+ const key = entry_obj.find("key").as_value();
+ // NOTE(review): `key` is unused and `title` is reassigned to the
+ // same literal as its initial value — looks like leftover code;
+ // confirm whether the key was meant to appear in the title.
+ title = `Cook Artifacts`;
+ }
+ catch (e)
+ {
+ console.error("Failed to fetch entry:", e);
+ }
+ }
+
+ const section = this.add_section(title);
+
+ // Fetch and parse artifact
+ let artifact;
+ try
+ {
+ artifact = await this._artifact;
+ }
+ catch (e)
+ {
+ section.text(`Failed to load artifact: ${e.message}`);
+ return;
+ }
+
+ // Display artifact info; each field is optional in the payload.
+ const info_section = section.add_section("Artifact Info");
+ const info_table = info_section.add_widget(Table, ["Property", "Value"], Table.Flag_PackRight);
+
+ if (artifact.Version !== undefined)
+ info_table.add_row("Version", artifact.Version.toString());
+ if (artifact.HasSaveResults !== undefined)
+ info_table.add_row("HasSaveResults", artifact.HasSaveResults.toString());
+ if (artifact.PackageSavedHash !== undefined)
+ info_table.add_row("PackageSavedHash", artifact.PackageSavedHash);
+
+ // Process SaveBuildDependencies
+ if (artifact.SaveBuildDependencies && artifact.SaveBuildDependencies.Dependencies)
+ {
+ this._build_dependency_section(
+ section,
+ "Save Build Dependencies",
+ artifact.SaveBuildDependencies.Dependencies,
+ artifact.SaveBuildDependencies.StoredKey
+ );
+ }
+
+ // Process LoadBuildDependencies
+ if (artifact.LoadBuildDependencies && artifact.LoadBuildDependencies.Dependencies)
+ {
+ this._build_dependency_section(
+ section,
+ "Load Build Dependencies",
+ artifact.LoadBuildDependencies.Dependencies,
+ artifact.LoadBuildDependencies.StoredKey
+ );
+ }
+
+ // Process RuntimeDependencies
+ if (artifact.RuntimeDependencies && artifact.RuntimeDependencies.length > 0)
+ {
+ const runtime_section = section.add_section("Runtime Dependencies");
+ const runtime_table = runtime_section.add_widget(Table, ["Path"], Table.Flag_PackRight);
+ for (const dep of artifact.RuntimeDependencies)
+ {
+ const row = runtime_table.add_row(dep);
+ // Make Path clickable to navigate to entry
+ if (this._should_link_dependency(dep))
+ {
+ row.get_cell(0).text(dep).on_click((opkey) => {
+ window.location = `?page=entry&project=${project}&oplog=${oplog}&opkey=${opkey.toLowerCase()}`;
+ }, dep);
+ }
+ }
+ }
+ }
+
+ // Decide whether a dependency name should link to its entry page.
+ _should_link_dependency(name)
+ {
+ // Exclude dependencies starting with /Script/ (code-defined entries) - case insensitive
+ if (name && name.toLowerCase().startsWith("/script/"))
+ return false;
+
+ return true;
+ }
+
+ // Render one dependency group (Save/Load build deps): shows the stored
+ // key, buckets dependencies by type, and emits one table per type with
+ // columns derived from the fields present. Name cells link to entry
+ // pages for selected types; large JSON Data cells get a +/- toggle.
+ _build_dependency_section(parent_section, title, dependencies, stored_key)
+ {
+ const section = parent_section.add_section(title);
+
+ // Add stored key info
+ if (stored_key)
+ {
+ const key_toolbar = section.add_widget(Toolbar);
+ key_toolbar.left().add(`Key: ${stored_key}`);
+ }
+
+ // Group dependencies by type
+ const dependencies_by_type = {};
+
+ for (const dep_array of dependencies)
+ {
+ if (!Array.isArray(dep_array) || dep_array.length === 0)
+ continue;
+
+ const type = dep_array[0];
+ if (!dependencies_by_type[type])
+ dependencies_by_type[type] = [];
+
+ dependencies_by_type[type].push(this._format_dependency(dep_array));
+ }
+
+ // Sort types numerically
+ const sorted_types = Object.keys(dependencies_by_type).map(Number).sort((a, b) => a - b);
+
+ for (const type_value of sorted_types)
+ {
+ const type_name = this._get_dependency_type_name(type_value);
+ const deps = dependencies_by_type[type_value];
+
+ const type_section = section.add_section(type_name);
+
+ // Determine columns based on available fields (union across deps)
+ const all_fields = new Set();
+ for (const dep of deps)
+ {
+ for (const field in dep)
+ all_fields.add(field);
+ }
+ let columns = Array.from(all_fields);
+
+ // Remove Hash column for RedirectionTarget as it's not useful
+ if (type_value === 10)
+ {
+ columns = columns.filter(col => col !== "Hash");
+ }
+
+ if (columns.length === 0)
+ {
+ type_section.text("No data fields");
+ continue;
+ }
+
+ // Create table with dynamic columns
+ const table = type_section.add_widget(Table, columns, Table.Flag_PackRight);
+
+ // Check if this type should have clickable Name links
+ const should_link = (type_value === 3 || type_value === 4 || type_value === 10);
+ const name_col_index = columns.indexOf("Name");
+
+ for (const dep of deps)
+ {
+ const row_values = columns.map(col => dep[col] || "");
+ const row = table.add_row(...row_values);
+
+ // Make Name field clickable for Package, TransitiveBuild, and RedirectionTarget
+ if (should_link && name_col_index >= 0 && dep.Name && this._should_link_dependency(dep.Name))
+ {
+ const project = this.get_param("project");
+ const oplog = this.get_param("oplog");
+ row.get_cell(name_col_index).text(dep.Name).on_click((opkey) => {
+ window.location = `?page=entry&project=${project}&oplog=${oplog}&opkey=${opkey.toLowerCase()}`;
+ }, dep.Name);
+ }
+
+ // Make Data field expandable/collapsible if needed
+ // NOTE(review): data_col_index is loop-invariant and could be
+ // hoisted out of the per-dep loop alongside name_col_index.
+ const data_col_index = columns.indexOf("Data");
+ if (data_col_index >= 0 && dep.Data)
+ {
+ const data_cell = row.get_cell(data_col_index);
+
+ if (this._should_make_expandable(dep.Data))
+ {
+ // Store full data in attribute
+ data_cell.attr("data-full", dep.Data);
+
+ // Clear the cell and rebuild with icon + text
+ data_cell.inner().innerHTML = "";
+
+ // Create expand/collapse icon
+ const icon = data_cell.tag("span").classify("zen_expand_icon").text("+");
+ icon.on_click(() => {
+ this._toggle_data_cell(data_cell);
+ // Update icon text
+ const is_expanded = data_cell.attr("expanded") !== null;
+ icon.text(is_expanded ? "-" : "+");
+ });
+
+ // Add text content wrapper
+ const text_wrapper = data_cell.tag("span").classify("zen_data_text");
+ const first_line = this._get_first_line(dep.Data);
+ text_wrapper.text(first_line);
+
+ // Store reference to text wrapper for updates
+ data_cell.attr("data-text-wrapper", "true");
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js
index 08589b090..1e4c82e3f 100644
--- a/src/zenserver/frontend/html/pages/entry.js
+++ b/src/zenserver/frontend/html/pages/entry.js
@@ -26,6 +26,9 @@ export class Page extends ZenPage
this._indexer = this.load_indexer(project, oplog);
+ this._files_index_start = Number(this.get_param("files_start", 0)) || 0;
+ this._files_index_count = Number(this.get_param("files_count", 50)) || 0;
+
this._build_page();
}
@@ -40,25 +43,39 @@ export class Page extends ZenPage
return indexer;
}
- async _build_deps(section, tree)
+ _build_deps(section, tree)
{
- const indexer = await this._indexer;
+ const project = this.get_param("project");
+ const oplog = this.get_param("oplog");
for (const dep_name in tree)
{
const dep_section = section.add_section(dep_name);
const table = dep_section.add_widget(Table, ["name", "id"], Table.Flag_PackRight);
+
for (const dep_id of tree[dep_name])
{
- const cell_values = ["", dep_id.toString(16).padStart(16, "0")];
+ const hex_id = dep_id.toString(16).padStart(16, "0");
+ const cell_values = ["loading...", hex_id];
const row = table.add_row(...cell_values);
- var opkey = indexer.lookup_id(dep_id);
- row.get_cell(0).text(opkey).on_click((k) => this.view_opkey(k), opkey);
+ // Asynchronously resolve the name
+ this._resolve_dep_name(row.get_cell(0), dep_id, project, oplog);
}
}
}
+ // Resolve a dependency id to its opkey name once the indexer is ready,
+ // then replace the placeholder cell text with a clickable link.
+ // Cells whose id cannot be resolved keep their placeholder text.
+ async _resolve_dep_name(cell, dep_id, project, oplog)
+ {
+ const indexer = await this._indexer;
+ const opkey = indexer.lookup_id(dep_id);
+
+ if (opkey)
+ {
+ cell.text(opkey).on_click((k) => this.view_opkey(k), opkey);
+ }
+ }
+
_find_iohash_field(container, name)
{
const found_field = container.find(name);
@@ -76,6 +93,21 @@ export class Page extends ZenPage
return null;
}
+ // True if the hash string is empty/absent or consists solely of '0'
+ // characters (an all-zero IoHash is treated as "no hash").
+ _is_null_io_hash_string(io_hash)
+ {
+ if (!io_hash)
+ return true;
+
+ for (let char of io_hash)
+ {
+ if (char != '0')
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+
async _build_meta(section, entry)
{
var tree = {}
@@ -123,11 +155,23 @@ export class Page extends ZenPage
const project = this.get_param("project");
const oplog = this.get_param("oplog");
+ const opkey = this.get_param("opkey");
const link = row.get_cell(0).link(
- "/" + ["prj", project, "oplog", oplog, value+".json"].join("/")
+ (key === "cook.artifacts") ?
+ `?page=cookartifacts&project=${project}&oplog=${oplog}&opkey=${opkey}&hash=${value}`
+ : "/" + ["prj", project, "oplog", oplog, value+".json"].join("/")
);
const action_tb = new Toolbar(row.get_cell(-1), true);
+
+ // Add "view-raw" button for cook.artifacts
+ if (key === "cook.artifacts")
+ {
+ action_tb.left().add("view-raw").on_click(() => {
+ window.location = "/" + ["prj", project, "oplog", oplog, value+".json"].join("/");
+ });
+ }
+
action_tb.left().add("copy-hash").on_click(async (v) => {
await navigator.clipboard.writeText(v);
}, value);
@@ -137,35 +181,55 @@ export class Page extends ZenPage
async _build_page()
{
var entry = await this._entry;
+
+ // Check if entry exists
+ if (!entry || entry.as_object().find("entry") == null)
+ {
+ const opkey = this.get_param("opkey");
+ var section = this.add_section("Entry Not Found");
+ section.tag("p").text(`The entry "${opkey}" is not present in this dataset.`);
+ section.tag("p").text("This could mean:");
+ const list = section.tag("ul");
+ list.tag("li").text("The entry is for an instance defined in code");
+ list.tag("li").text("The entry has not been added to the oplog yet");
+ list.tag("li").text("The entry key is misspelled");
+ list.tag("li").text("The entry was removed or never existed");
+ return;
+ }
+
entry = entry.as_object().find("entry").as_object();
const name = entry.find("key").as_value();
var section = this.add_section(name);
+ var has_package_data = false;
// tree
{
var tree = entry.find("$tree");
if (tree == undefined)
tree = this._convert_legacy_to_tree(entry);
- if (tree == undefined)
- return this._display_unsupported(section, entry);
-
- delete tree["$id"];
-
- if (Object.keys(tree).length != 0)
+ if (tree != undefined)
{
- const sub_section = section.add_section("deps");
- this._build_deps(sub_section, tree);
+ delete tree["$id"];
+
+ if (Object.keys(tree).length != 0)
+ {
+ const sub_section = section.add_section("dependencies");
+ this._build_deps(sub_section, tree);
+ }
+ has_package_data = true;
}
}
// meta
+ if (has_package_data)
{
this._build_meta(section, entry);
}
// data
+ if (has_package_data)
{
const sub_section = section.add_section("data");
const table = sub_section.add_widget(
@@ -181,7 +245,7 @@ export class Page extends ZenPage
for (const item of pkg_data.as_array())
{
- var io_hash, size, raw_size, file_name;
+ var io_hash = undefined, size = undefined, raw_size = undefined, file_name = undefined;
for (const field of item.as_object())
{
if (field.is_named("data")) io_hash = field.as_value();
@@ -198,8 +262,8 @@ export class Page extends ZenPage
io_hash = ret;
}
- size = (size !== undefined) ? Friendly.kib(size) : "";
- raw_size = (raw_size !== undefined) ? Friendly.kib(raw_size) : "";
+ size = (size !== undefined) ? Friendly.bytes(size) : "";
+ raw_size = (raw_size !== undefined) ? Friendly.bytes(raw_size) : "";
const row = table.add_row(file_name, size, raw_size);
@@ -219,12 +283,76 @@ export class Page extends ZenPage
}
}
+ // files
+ var has_file_data = false;
+ {
+ var file_data = entry.find("files");
+ if (file_data != undefined)
+ {
+ has_file_data = true;
+
+ // Extract files into array
+ this._files_data = [];
+ for (const item of file_data.as_array())
+ {
+ var io_hash = undefined, cid = undefined, server_path = undefined, client_path = undefined;
+ for (const field of item.as_object())
+ {
+ if (field.is_named("data")) io_hash = field.as_value();
+ else if (field.is_named("id")) cid = field.as_value();
+ else if (field.is_named("serverpath")) server_path = field.as_value();
+ else if (field.is_named("clientpath")) client_path = field.as_value();
+ }
+
+ if (io_hash instanceof Uint8Array)
+ {
+ var ret = "";
+ for (var x of io_hash)
+ ret += x.toString(16).padStart(2, "0");
+ io_hash = ret;
+ }
+
+ if (cid instanceof Uint8Array)
+ {
+ var ret = "";
+ for (var x of cid)
+ ret += x.toString(16).padStart(2, "0");
+ cid = ret;
+ }
+
+ this._files_data.push({
+ server_path: server_path,
+ client_path: client_path,
+ io_hash: io_hash,
+ cid: cid
+ });
+ }
+
+ this._files_index_max = this._files_data.length;
+
+ const sub_section = section.add_section("files");
+ this._build_files_nav(sub_section);
+
+ this._files_table = sub_section.add_widget(
+ Table,
+ ["name", "actions"], Table.Flag_PackRight
+ );
+ this._files_table.id("filetable");
+
+ this._build_files_table(this._files_index_start);
+ }
+ }
+
// props
+ if (has_package_data)
{
const object = entry.to_js_object();
var sub_section = section.add_section("props");
sub_section.add_widget(PropTable).add_object(object);
}
+
+ if (!has_package_data && !has_file_data)
+ return this._display_unsupported(section, entry);
}
_display_unsupported(section, entry)
@@ -271,16 +399,30 @@ export class Page extends ZenPage
for (const field of pkgst_entry)
{
const field_name = field.get_name();
- if (!field_name.endsWith("importedpackageids"))
- continue;
-
- var dep_name = field_name.slice(0, -18);
- if (dep_name.length == 0)
- dep_name = "imported";
-
- var out = tree[dep_name] = [];
- for (var item of field.as_array())
- out.push(item.as_value(BigInt));
+ if (field_name.endsWith("importedpackageids"))
+ {
+ var dep_name = field_name.slice(0, -18);
+ if (dep_name.length == 0)
+ dep_name = "hard";
+ else
+ dep_name = "hard." + dep_name;
+
+ var out = tree[dep_name] = [];
+ for (var item of field.as_array())
+ out.push(item.as_value(BigInt));
+ }
+ else if (field_name.endsWith("softpackagereferences"))
+ {
+ var dep_name = field_name.slice(0, -21);
+ if (dep_name.length == 0)
+ dep_name = "soft";
+ else
+ dep_name = "soft." + dep_name;
+
+ var out = tree[dep_name] = [];
+ for (var item of field.as_array())
+ out.push(item.as_value(BigInt));
+ }
}
return tree;
@@ -292,4 +434,149 @@ export class Page extends ZenPage
params.set("opkey", opkey);
window.location.search = params;
}
+
+ // Build the files-table toolbar: paging buttons (the +/-10e10 steps act
+ // as jump-to-first/last since _build_files_table clamps the index),
+ // page-size choices, total count, and a search box.
+ _build_files_nav(section)
+ {
+ const nav = section.add_widget(Toolbar);
+ const left = nav.left();
+ left.add("|<") .on_click(() => this._on_files_next_prev(-10e10));
+ left.add("<<") .on_click(() => this._on_files_next_prev(-10));
+ left.add("prev").on_click(() => this._on_files_next_prev( -1));
+ left.add("next").on_click(() => this._on_files_next_prev( 1));
+ left.add(">>") .on_click(() => this._on_files_next_prev( 10));
+ left.add(">|") .on_click(() => this._on_files_next_prev( 10e10));
+
+ left.sep();
+ for (var count of [10, 25, 50, 100])
+ {
+ var handler = (n) => this._on_files_change_count(n);
+ left.add(count).on_click(handler, count);
+ }
+
+ const right = nav.right();
+ right.add(Friendly.sep(this._files_index_max));
+
+ right.sep();
+ var search_input = right.add("search:", "label").tag("input");
+ search_input.on("change", (x) => this._search_files(x.inner().value), search_input);
+ }
+
+ // (Re)populate the files table starting at `index`, clamped to the
+ // valid range. Rows link to the blob by IoHash, or by chunk id when the
+ // IoHash is null. Persists the page position into the URL params.
+ _build_files_table(index)
+ {
+ this._files_index_count = Math.max(this._files_index_count, 1);
+ index = Math.min(index, this._files_index_max - this._files_index_count);
+ index = Math.max(index, 0);
+
+ const project = this.get_param("project");
+ const oplog = this.get_param("oplog");
+
+ const end_index = Math.min(index + this._files_index_count, this._files_index_max);
+
+ this._files_table.clear(index);
+ for (var i = index; i < end_index; i++)
+ {
+ const file_item = this._files_data[i];
+ const row = this._files_table.add_row(file_item.server_path);
+
+ // Base name handles both '/' and '\' separators in server paths.
+ var base_name = file_item.server_path.split("/").pop().split("\\").pop();
+ if (this._is_null_io_hash_string(file_item.io_hash))
+ {
+ // No content hash: link/download by chunk id instead.
+ const link = row.get_cell(0).link(
+ "/" + ["prj", project, "oplog", oplog, file_item.cid].join("/")
+ );
+ link.first_child().attr("download", `${file_item.cid}_${base_name}`);
+
+ const action_tb = new Toolbar(row.get_cell(-1), true);
+ action_tb.left().add("copy-id").on_click(async (v) => {
+ await navigator.clipboard.writeText(v);
+ }, file_item.cid);
+ }
+ else
+ {
+ const link = row.get_cell(0).link(
+ "/" + ["prj", project, "oplog", oplog, file_item.io_hash].join("/")
+ );
+ link.first_child().attr("download", `${file_item.io_hash}_${base_name}`);
+
+ const action_tb = new Toolbar(row.get_cell(-1), true);
+ action_tb.left().add("copy-hash").on_click(async (v) => {
+ await navigator.clipboard.writeText(v);
+ }, file_item.io_hash);
+ }
+ }
+
+ // Remember position in the URL so reloads keep the same page.
+ this.set_param("files_start", index);
+ this.set_param("files_count", this._files_index_count);
+ this._files_index_start = index;
+ }
+
+ // Toolbar handler: switch the files page size and rebuild in place.
+ _on_files_change_count(value)
+ {
+ this._files_index_count = parseInt(value);
+ this._build_files_table(this._files_index_start);
+ }
+
+ // Toolbar handler: page by `direction` pages (negative = backwards);
+ // upper clamping is done inside _build_files_table.
+ _on_files_next_prev(direction)
+ {
+ var index = this._files_index_start + (this._files_index_count * direction);
+ index = Math.max(0, index);
+ this._build_files_table(index);
+ }
+
+ // Case-insensitive substring search over file server paths. An empty
+ // needle restores normal paging; results are capped at `searchmax`
+ // (default 250) with a "...truncated" marker row.
+ _search_files(needle)
+ {
+ if (needle.length == 0)
+ {
+ this._build_files_table(this._files_index_start);
+ return;
+ }
+ needle = needle.trim().toLowerCase();
+
+ this._files_table.clear(this._files_index_start);
+
+ const project = this.get_param("project");
+ const oplog = this.get_param("oplog");
+
+ var added = 0;
+ const truncate_at = this.get_param("searchmax") || 250;
+ for (const file_item of this._files_data)
+ {
+ if (!file_item.server_path.toLowerCase().includes(needle))
+ continue;
+
+ const row = this._files_table.add_row(file_item.server_path);
+
+ // Same row construction as _build_files_table: link by io_hash,
+ // falling back to chunk id when the hash is null.
+ var base_name = file_item.server_path.split("/").pop().split("\\").pop();
+ if (this._is_null_io_hash_string(file_item.io_hash))
+ {
+ const link = row.get_cell(0).link(
+ "/" + ["prj", project, "oplog", oplog, file_item.cid].join("/")
+ );
+ link.first_child().attr("download", `${file_item.cid}_${base_name}`);
+
+ const action_tb = new Toolbar(row.get_cell(-1), true);
+ action_tb.left().add("copy-id").on_click(async (v) => {
+ await navigator.clipboard.writeText(v);
+ }, file_item.cid);
+ }
+ else
+ {
+ const link = row.get_cell(0).link(
+ "/" + ["prj", project, "oplog", oplog, file_item.io_hash].join("/")
+ );
+ link.first_child().attr("download", `${file_item.io_hash}_${base_name}`);
+
+ const action_tb = new Toolbar(row.get_cell(-1), true);
+ action_tb.left().add("copy-hash").on_click(async (v) => {
+ await navigator.clipboard.writeText(v);
+ }, file_item.io_hash);
+ }
+
+ if (++added >= truncate_at)
+ {
+ this._files_table.add_row("...truncated");
+ break;
+ }
+ }
+ }
}
diff --git a/src/zenserver/frontend/html/pages/hub.js b/src/zenserver/frontend/html/pages/hub.js
new file mode 100644
index 000000000..f9e4fff33
--- /dev/null
+++ b/src/zenserver/frontend/html/pages/hub.js
@@ -0,0 +1,122 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+"use strict";
+
+import { ZenPage } from "./page.js"
+import { Fetcher } from "../util/fetcher.js"
+import { Friendly } from "../util/friendly.js"
+import { Table } from "../util/widgets.js"
+
+////////////////////////////////////////////////////////////////////////////////
+export class Page extends ZenPage
+{
+ // Page entry point: lays out the Capacity grid and Modules section,
+ // then polls /hub endpoints every 2 seconds.
+ async main()
+ {
+ this.set_title("hub");
+
+ // Capacity
+ const stats_section = this.add_section("Capacity");
+ this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles");
+
+ // Modules
+ const mod_section = this.add_section("Modules");
+ this._mod_host = mod_section;
+ this._mod_table = null;
+
+ await this._update();
+ // NOTE(review): the interval is never cleared here — confirm ZenPage
+ // tears pages down via navigation so the timer cannot leak.
+ this._poll_timer = setInterval(() => this._update(), 2000);
+ }
+
+ // Poll tick: fetch hub stats and status in parallel and re-render.
+ // Errors are swallowed deliberately so a temporarily unavailable hub
+ // service just skips a refresh.
+ async _update()
+ {
+ try
+ {
+ const [stats, status] = await Promise.all([
+ new Fetcher().resource("/hub/stats").json(),
+ new Fetcher().resource("/hub/status").json(),
+ ]);
+
+ this._render_capacity(stats);
+ this._render_modules(status);
+ }
+ catch (e) { /* service unavailable */ }
+ }
+
+ // Rebuild the Capacity tiles (active / peak / limit) from /hub/stats
+ // data. The grid is cleared and fully re-created on every poll.
+ _render_capacity(data)
+ {
+ const grid = this._stats_grid;
+ grid.inner().innerHTML = "";
+
+ const current = data.currentInstanceCount || 0;
+ const max = data.maxInstanceCount || 0;
+ const limit = data.instanceLimit || 0;
+
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Active Modules");
+ const body = tile.tag().classify("tile-metrics");
+ this._metric(body, Friendly.sep(current), "currently provisioned", true);
+ }
+
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Peak Modules");
+ const body = tile.tag().classify("tile-metrics");
+ this._metric(body, Friendly.sep(max), "high watermark", true);
+ }
+
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Instance Limit");
+ const body = tile.tag().classify("tile-metrics");
+ this._metric(body, Friendly.sep(limit), "maximum allowed", true);
+ // Utilization only makes sense when a limit is configured (> 0).
+ if (limit > 0)
+ {
+ const pct = ((current / limit) * 100).toFixed(0) + "%";
+ this._metric(body, pct, "utilization");
+ }
+ }
+ }
+
+ // Rebuild the Modules table from /hub/status data; the table widget is
+ // created lazily on first render and cleared on subsequent polls.
+ _render_modules(data)
+ {
+ const modules = data.modules || [];
+
+ if (this._mod_table)
+ {
+ this._mod_table.clear();
+ }
+ else
+ {
+ this._mod_table = this._mod_host.add_widget(
+ Table,
+ ["module ID", "status"],
+ Table.Flag_FitLeft|Table.Flag_PackRight
+ );
+ }
+
+ if (modules.length === 0)
+ {
+ return;
+ }
+
+ for (const m of modules)
+ {
+ this._mod_table.add_row(
+ m.moduleId || "",
+ m.provisioned ? "provisioned" : "inactive",
+ );
+ }
+ }
+
+ // Append a metric (value + label) element to a tile body; `hero` adds
+ // the emphasis class for the tile's primary metric.
+ _metric(parent, value, label, hero = false)
+ {
+ const m = parent.tag().classify("tile-metric");
+ if (hero)
+ {
+ m.classify("tile-metric-hero");
+ }
+ m.tag().classify("metric-value").text(value);
+ m.tag().classify("metric-label").text(label);
+ }
+}
diff --git a/src/zenserver/frontend/html/pages/info.js b/src/zenserver/frontend/html/pages/info.js
new file mode 100644
index 000000000..f92765c78
--- /dev/null
+++ b/src/zenserver/frontend/html/pages/info.js
@@ -0,0 +1,261 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+"use strict";
+
+import { ZenPage } from "./page.js"
+import { Fetcher } from "../util/fetcher.js"
+import { Friendly } from "../util/friendly.js"
+
+////////////////////////////////////////////////////////////////////////////////
+export class Page extends ZenPage
+{
+ // Page entry point: fetches server info, GC state, the service list and
+ // a detailed version string in parallel (gc/services tolerate failure),
+ // then renders one card per category.
+ async main()
+ {
+ this.set_title("info");
+
+ const [info, gc, services, version] = await Promise.all([
+ new Fetcher().resource("/health/info").json(),
+ new Fetcher().resource("/admin/gc").json().catch(() => null),
+ new Fetcher().resource("/api/").json().catch(() => ({})),
+ new Fetcher().resource("/health/version").param("detailed", "true").text(),
+ ]);
+
+ const section = this.add_section("Server Info");
+ const grid = section.tag().classify("grid").classify("info-tiles");
+
+ // Application
+ {
+ const tile = grid.tag().classify("card").classify("info-tile");
+ tile.tag().classify("card-title").text("Application");
+ const list = tile.tag().classify("info-props");
+
+ this._prop(list, "version", version || info.BuildVersion || "-");
+ this._prop(list, "http server", info.HttpServerClass || "-");
+ this._prop(list, "port", info.Port || "-");
+ this._prop(list, "pid", info.Pid || "-");
+ this._prop(list, "dedicated", info.IsDedicated ? "yes" : "no");
+
+ if (info.StartTimeMs)
+ {
+ const start = new Date(info.StartTimeMs);
+ const elapsed = Date.now() - info.StartTimeMs;
+ this._prop(list, "started", start.toLocaleString());
+ this._prop(list, "uptime", this._format_duration(elapsed));
+ }
+
+ this._prop(list, "data root", info.DataRoot || "-");
+ this._prop(list, "log path", info.AbsLogPath || "-");
+ }
+
+ // System
+ {
+ const tile = grid.tag().classify("card").classify("info-tile");
+ tile.tag().classify("card-title").text("System");
+ const list = tile.tag().classify("info-props");
+
+ this._prop(list, "hostname", info.Hostname || "-");
+ this._prop(list, "platform", info.Platform || "-");
+ this._prop(list, "os", info.OS || "-");
+ this._prop(list, "arch", info.Arch || "-");
+
+ const sys = info.System;
+ if (sys)
+ {
+ this._prop(list, "cpus", sys.cpu_count || "-");
+ this._prop(list, "cores", sys.core_count || "-");
+ this._prop(list, "logical processors", sys.lp_count || "-");
+ // Memory fields arrive in MiB; convert to bytes for Friendly.bytes.
+ this._prop(list, "total memory", sys.total_memory_mb ? Friendly.bytes(sys.total_memory_mb * 1048576) : "-");
+ this._prop(list, "available memory", sys.avail_memory_mb ? Friendly.bytes(sys.avail_memory_mb * 1048576) : "-");
+ if (sys.uptime_seconds)
+ {
+ this._prop(list, "system uptime", this._format_duration(sys.uptime_seconds * 1000));
+ }
+ }
+ }
+
+ // Runtime Configuration
+ if (info.RuntimeConfig)
+ {
+ const tile = grid.tag().classify("card").classify("info-tile");
+ tile.tag().classify("card-title").text("Runtime Configuration");
+ const list = tile.tag().classify("info-props");
+
+ for (const key in info.RuntimeConfig)
+ {
+ this._prop(list, key, info.RuntimeConfig[key] || "-");
+ }
+ }
+
+ // Build Configuration
+ if (info.BuildConfig)
+ {
+ const tile = grid.tag().classify("card").classify("info-tile");
+ tile.tag().classify("card-title").text("Build Configuration");
+ const list = tile.tag().classify("info-props");
+
+ // Build config values are treated as booleans.
+ for (const key in info.BuildConfig)
+ {
+ this._prop(list, key, info.BuildConfig[key] ? "yes" : "no");
+ }
+ }
+
+ // Services
+ {
+ const tile = grid.tag().classify("card").classify("info-tile");
+ tile.tag().classify("card-title").text("Services");
+ const list = tile.tag().classify("info-props");
+
+ const svc_list = (services.services || []).map(s => s.base_uri).sort();
+ for (const uri of svc_list)
+ {
+ this._prop(list, uri, "registered");
+ }
+ }
+
+ // Garbage Collection (card shown only if /admin/gc was reachable)
+ if (gc)
+ {
+ const tile = grid.tag().classify("card").classify("info-tile");
+ tile.tag().classify("card-title").text("Garbage Collection");
+ const list = tile.tag().classify("info-props");
+
+ this._prop(list, "status", gc.Status || "-");
+
+ if (gc.AreDiskWritesBlocked !== undefined)
+ {
+ this._prop(list, "disk writes blocked", gc.AreDiskWritesBlocked ? "yes" : "no");
+ }
+
+ if (gc.DiskSize)
+ {
+ this._prop(list, "disk size", gc.DiskSize);
+ this._prop(list, "disk used", gc.DiskUsed);
+ this._prop(list, "disk free", gc.DiskFree);
+ }
+
+ const cfg = gc.Config;
+ if (cfg)
+ {
+ this._prop(list, "gc enabled", cfg.Enabled ? "yes" : "no");
+ if (cfg.Interval)
+ {
+ this._prop(list, "interval", this._friendly_duration(cfg.Interval));
+ }
+ if (cfg.LightweightInterval)
+ {
+ this._prop(list, "lightweight interval", this._friendly_duration(cfg.LightweightInterval));
+ }
+ if (cfg.MaxCacheDuration)
+ {
+ this._prop(list, "max cache duration", this._friendly_duration(cfg.MaxCacheDuration));
+ }
+ if (cfg.MaxProjectStoreDuration)
+ {
+ this._prop(list, "max project duration", this._friendly_duration(cfg.MaxProjectStoreDuration));
+ }
+ if (cfg.MaxBuildStoreDuration)
+ {
+ this._prop(list, "max build duration", this._friendly_duration(cfg.MaxBuildStoreDuration));
+ }
+ }
+
+ if (gc.FullGC)
+ {
+ if (gc.FullGC.LastTime)
+ {
+ this._prop(list, "last full gc", this._friendly_timestamp(gc.FullGC.LastTime));
+ }
+ if (gc.FullGC.TimeToNext)
+ {
+ this._prop(list, "next full gc", this._friendly_duration(gc.FullGC.TimeToNext));
+ }
+ }
+
+ if (gc.LightweightGC)
+ {
+ if (gc.LightweightGC.LastTime)
+ {
+ this._prop(list, "last lightweight gc", this._friendly_timestamp(gc.LightweightGC.LastTime));
+ }
+ if (gc.LightweightGC.TimeToNext)
+ {
+ this._prop(list, "next lightweight gc", this._friendly_duration(gc.LightweightGC.TimeToNext));
+ }
+ }
+ }
+ }
+
+ // Append a label/value property row. Values that look like filesystem
+ // paths (Windows drive prefix or leading '/') become vscode:// links.
+ // NOTE(review): the leading-'/' test will also match URI-like values
+ // such as service base URIs — confirm that is intended.
+ _prop(parent, label, value)
+ {
+ const row = parent.tag().classify("info-prop");
+ row.tag().classify("info-prop-label").text(label);
+ const val = row.tag().classify("info-prop-value");
+ const str = String(value);
+ if (str.match(/^[A-Za-z]:[\\/]/) || str.startsWith("/"))
+ {
+ val.tag("a").text(str).attr("href", "vscode://" + str.replace(/\\/g, "/"));
+ }
+ else
+ {
+ val.text(str);
+ }
+ }
+
+ // Render a timestamp in the user's locale; falls back to the raw
+ // string when the value does not parse as a Date.
+ _friendly_timestamp(value)
+ {
+ const d = new Date(value);
+ if (isNaN(d.getTime()))
+ {
+ return String(value);
+ }
+ return d.toLocaleString(undefined, {
+ year: "numeric", month: "short", day: "numeric",
+ hour: "2-digit", minute: "2-digit", second: "2-digit",
+ });
+ }
+
+ // Render a duration: numbers are treated as milliseconds; strings in
+ // .NET TimeSpan form ("[d.]hh:mm:ss[.fff]") are parsed, anything else
+ // is returned unchanged. Fractional seconds are intentionally dropped.
+ _friendly_duration(value)
+ {
+ if (typeof value === "number")
+ {
+ return this._format_duration(value);
+ }
+
+ const str = String(value);
+ const match = str.match(/^[+-]?(?:(\d+)\.)?(\d+):(\d+):(\d+)(?:\.(\d+))?$/);
+ if (!match)
+ {
+ return str;
+ }
+
+ const days = parseInt(match[1] || "0", 10);
+ const hours = parseInt(match[2], 10);
+ const minutes = parseInt(match[3], 10);
+ const seconds = parseInt(match[4], 10);
+ const total_seconds = days * 86400 + hours * 3600 + minutes * 60 + seconds;
+
+ return this._format_duration(total_seconds * 1000);
+ }
+
+ // Format a millisecond count as a short human-readable duration using
+ // the two or three most significant units (e.g. "2d 3h 10m", "45s").
+ _format_duration(ms)
+ {
+ const seconds = Math.floor(ms / 1000);
+ const minutes = Math.floor(seconds / 60);
+ const hours = Math.floor(minutes / 60);
+ const days = Math.floor(hours / 24);
+
+ if (days > 0)
+ {
+ return `${days}d ${hours % 24}h ${minutes % 60}m`;
+ }
+ if (hours > 0)
+ {
+ return `${hours}h ${minutes % 60}m`;
+ }
+ if (minutes > 0)
+ {
+ return `${minutes}m ${seconds % 60}s`;
+ }
+ return `${seconds}s`;
+ }
+}
diff --git a/src/zenserver/frontend/html/pages/map.js b/src/zenserver/frontend/html/pages/map.js
index 58046b255..ac8f298aa 100644
--- a/src/zenserver/frontend/html/pages/map.js
+++ b/src/zenserver/frontend/html/pages/map.js
@@ -116,9 +116,9 @@ export class Page extends ZenPage
for (const name of sorted_keys)
nodes.push(new_nodes[name] / branch_size);
- var stats = Friendly.kib(branch_size);
+ var stats = Friendly.bytes(branch_size);
stats += " / ";
- stats += Friendly.kib(total_size);
+ stats += Friendly.bytes(total_size);
stats += " (";
stats += 0|((branch_size * 100) / total_size);
stats += "%)";
diff --git a/src/zenserver/frontend/html/pages/metrics.js b/src/zenserver/frontend/html/pages/metrics.js
new file mode 100644
index 000000000..e7a2eca67
--- /dev/null
+++ b/src/zenserver/frontend/html/pages/metrics.js
@@ -0,0 +1,232 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+"use strict";
+
+import { ZenPage } from "./page.js"
+import { Fetcher } from "../util/fetcher.js"
+import { Friendly } from "../util/friendly.js"
+import { PropTable, Toolbar } from "../util/widgets.js"
+
+////////////////////////////////////////////////////////////////////////////////
+class TemporalStat
+{
+ constructor(data, as_bytes)
+ {
+ this._data = data;
+ this._as_bytes = as_bytes;
+ }
+
+ toString()
+ {
+ const columns = [
+ /* count */ {},
+ /* rate */ {},
+ /* t */ {}, {},
+ ];
+ const data = this._data;
+ for (var key in data)
+ {
+ var out = columns[0];
+ if (key.startsWith("rate_")) out = columns[1];
+ else if (key.startsWith("t_p")) out = columns[3];
+ else if (key.startsWith("t_")) out = columns[2];
+ out[key] = data[key];
+ }
+
+ var friendly = this._as_bytes ? Friendly.bytes : Friendly.sep;
+
+ var content = "";
+ for (var i = 0; i < columns.length; ++i)
+ {
+ const column = columns[i];
+ for (var key in column)
+ {
+ var value = column[key];
+ if (i)
+ {
+ value = Friendly.sep(value, 2);
+ key = key.padStart(9);
+ content += key + ": " + value;
+ }
+ else
+ content += friendly(value);
+ content += "\r\n";
+ }
+ }
+
+ return content;
+ }
+
+ tag()
+ {
+ return "pre";
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+export class Page extends ZenPage
+{
+ async main()
+ {
+ this.set_title("metrics");
+
+ const metrics_section = this.add_section("metrics");
+ const top_toolbar = metrics_section.add_widget(Toolbar);
+ const tb_right = top_toolbar.right();
+ this._refresh_label = tb_right.add("", "label");
+ this._pause_btn = tb_right.add("pause").on_click(() => this._toggle_pause());
+
+ this._paused = false;
+ this._last_refresh = Date.now();
+ this._provider_views = [];
+
+ const providers_data = await new Fetcher().resource("stats").json();
+ const providers = providers_data["providers"] || [];
+
+ const stats_list = await Promise.all(providers.map((provider) =>
+ new Fetcher()
+ .resource("stats", provider)
+ .param("cidstorestats", "true")
+ .param("cachestorestats", "true")
+ .json()
+ .then((stats) => ({ provider, stats }))
+ ));
+
+ for (const { provider, stats } of stats_list)
+ {
+ this._condense(stats);
+ this._provider_views.push(this._render_provider(provider, stats));
+ }
+
+ this._last_refresh = Date.now();
+ this._update_refresh_label();
+
+ this._timer_id = setInterval(() => this._refresh(), 5000);
+ this._tick_id = setInterval(() => this._update_refresh_label(), 1000);
+
+ document.addEventListener("visibilitychange", () => {
+ if (document.hidden)
+ this._pause_timer(false);
+ else if (!this._paused)
+ this._resume_timer();
+ });
+ }
+
+ _render_provider(provider, stats)
+ {
+ const section = this.add_section(provider);
+ const toolbar = section.add_widget(Toolbar);
+
+ toolbar.right().add("detailed →").on_click(() => {
+ window.location = "?page=stat&provider=" + provider;
+ });
+
+ const table = section.add_widget(PropTable);
+ let current_stats = stats;
+ let current_category = undefined;
+
+ const show_category = (cat) => {
+ current_category = cat;
+ table.clear();
+ table.add_object(current_stats[cat], true, 3);
+ };
+
+ var first = undefined;
+ for (var name in stats)
+ {
+ first = first || name;
+ toolbar.left().add(name).on_click(show_category, name);
+ }
+
+ if (first)
+ show_category(first);
+
+ return {
+ provider,
+ set_stats: (new_stats) => {
+ current_stats = new_stats;
+ if (current_category && current_stats[current_category])
+ show_category(current_category);
+ },
+ };
+ }
+
+ async _refresh()
+ {
+ const updates = await Promise.all(this._provider_views.map((view) =>
+ new Fetcher()
+ .resource("stats", view.provider)
+ .param("cidstorestats", "true")
+ .param("cachestorestats", "true")
+ .json()
+ .then((stats) => ({ view, stats }))
+ ));
+
+ for (const { view, stats } of updates)
+ {
+ this._condense(stats);
+ view.set_stats(stats);
+ }
+
+ this._last_refresh = Date.now();
+ this._update_refresh_label();
+ }
+
+ _update_refresh_label()
+ {
+ const elapsed = Math.floor((Date.now() - this._last_refresh) / 1000);
+ this._refresh_label.inner().textContent = "refreshed " + elapsed + "s ago";
+ }
+
+ _toggle_pause()
+ {
+ if (this._paused)
+ this._resume_timer();
+ else
+ this._pause_timer(true);
+ }
+
+ _pause_timer(user_paused=true)
+ {
+ clearInterval(this._timer_id);
+ this._timer_id = undefined;
+ if (user_paused)
+ {
+ this._paused = true;
+ this._pause_btn.inner().textContent = "resume";
+ }
+ }
+
+ _resume_timer()
+ {
+ this._paused = false;
+ this._pause_btn.inner().textContent = "pause";
+ this._timer_id = setInterval(() => this._refresh(), 5000);
+ this._refresh();
+ }
+
+ _condense(stats)
+ {
+ const impl = function(node)
+ {
+ for (var name in node)
+ {
+ const candidate = node[name];
+ if (!(candidate instanceof Object))
+ continue;
+
+ if (candidate["rate_mean"] != undefined)
+ {
+ const as_bytes = (name.indexOf("bytes") >= 0);
+ node[name] = new TemporalStat(candidate, as_bytes);
+ continue;
+ }
+
+ impl(candidate);
+ }
+ }
+
+ for (var name in stats)
+ impl(stats[name]);
+ }
+}
diff --git a/src/zenserver/frontend/html/pages/oplog.js b/src/zenserver/frontend/html/pages/oplog.js
index 879fc4c97..fb857affb 100644
--- a/src/zenserver/frontend/html/pages/oplog.js
+++ b/src/zenserver/frontend/html/pages/oplog.js
@@ -32,7 +32,7 @@ export class Page extends ZenPage
this.set_title("oplog - " + oplog);
- var section = this.add_section(project + " - " + oplog);
+ var section = this.add_section(oplog);
oplog_info = await oplog_info;
this._index_max = oplog_info["opcount"];
@@ -81,7 +81,7 @@ export class Page extends ZenPage
const right = nav.right();
right.add(Friendly.sep(oplog_info["opcount"]));
- right.add("(" + Friendly.kib(oplog_info["totalsize"]) + ")");
+ right.add("(" + Friendly.bytes(oplog_info["totalsize"]) + ")");
right.sep();
var search_input = right.add("search:", "label").tag("input")
diff --git a/src/zenserver/frontend/html/pages/orchestrator.js b/src/zenserver/frontend/html/pages/orchestrator.js
new file mode 100644
index 000000000..24805c722
--- /dev/null
+++ b/src/zenserver/frontend/html/pages/orchestrator.js
@@ -0,0 +1,405 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+"use strict";
+
+import { ZenPage } from "./page.js"
+import { Fetcher } from "../util/fetcher.js"
+import { Friendly } from "../util/friendly.js"
+import { Table } from "../util/widgets.js"
+
+////////////////////////////////////////////////////////////////////////////////
+export class Page extends ZenPage
+{
+ async main()
+ {
+ this.set_title("orchestrator");
+
+ // Agents section
+ const agents_section = this._collapsible_section("Compute Agents");
+ this._agents_host = agents_section;
+ this._agents_table = null;
+
+ // Clients section
+ const clients_section = this._collapsible_section("Connected Clients");
+ this._clients_host = clients_section;
+ this._clients_table = null;
+
+ // Event history
+ const history_section = this._collapsible_section("Worker Events");
+ this._history_host = history_section;
+ this._history_table = null;
+
+ const client_history_section = this._collapsible_section("Client Events");
+ this._client_history_host = client_history_section;
+ this._client_history_table = null;
+
+ this._ws_paused = false;
+ try { this._ws_paused = localStorage.getItem("zen-ws-paused") === "true"; } catch (e) {}
+ document.addEventListener("zen-ws-toggle", (e) => {
+ this._ws_paused = e.detail.paused;
+ });
+
+ // Initial fetch
+ await this._fetch_all();
+
+ // Connect WebSocket for live updates, fall back to polling
+ this._connect_ws();
+ }
+
+ _collapsible_section(name)
+ {
+ const section = this.add_section(name);
+ const container = section._parent.inner();
+ const heading = container.firstElementChild;
+
+ heading.style.cursor = "pointer";
+ heading.style.userSelect = "none";
+
+ const indicator = document.createElement("span");
+ indicator.textContent = " \u25BC";
+ indicator.style.fontSize = "0.7em";
+ heading.appendChild(indicator);
+
+ let collapsed = false;
+ heading.addEventListener("click", (e) => {
+ if (e.target !== heading && e.target !== indicator)
+ {
+ return;
+ }
+ collapsed = !collapsed;
+ indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
+ let sibling = heading.nextElementSibling;
+ while (sibling)
+ {
+ sibling.style.display = collapsed ? "none" : "";
+ sibling = sibling.nextElementSibling;
+ }
+ });
+
+ return section;
+ }
+
+ async _fetch_all()
+ {
+ try
+ {
+ const [agents, history, clients, client_history] = await Promise.all([
+ new Fetcher().resource("/orch/agents").json(),
+ new Fetcher().resource("/orch/history").param("limit", "50").json().catch(() => null),
+ new Fetcher().resource("/orch/clients").json().catch(() => null),
+ new Fetcher().resource("/orch/clients/history").param("limit", "50").json().catch(() => null),
+ ]);
+
+ this._render_agents(agents);
+ if (history)
+ {
+ this._render_history(history.events || []);
+ }
+ if (clients)
+ {
+ this._render_clients(clients.clients || []);
+ }
+ if (client_history)
+ {
+ this._render_client_history(client_history.client_events || []);
+ }
+ }
+ catch (e) { /* service unavailable */ }
+ }
+
+ _connect_ws()
+ {
+ try
+ {
+ const proto = location.protocol === "https:" ? "wss:" : "ws:";
+ const ws = new WebSocket(`${proto}//${location.host}/orch/ws`);
+
+ ws.onopen = () => {
+ if (this._poll_timer)
+ {
+ clearInterval(this._poll_timer);
+ this._poll_timer = null;
+ }
+ };
+
+ ws.onmessage = (ev) => {
+ if (this._ws_paused)
+ {
+ return;
+ }
+ try
+ {
+ const data = JSON.parse(ev.data);
+ this._render_agents(data);
+ if (data.events)
+ {
+ this._render_history(data.events);
+ }
+ if (data.clients)
+ {
+ this._render_clients(data.clients);
+ }
+ if (data.client_events)
+ {
+ this._render_client_history(data.client_events);
+ }
+ }
+ catch (e) { /* ignore parse errors */ }
+ };
+
+ ws.onclose = () => {
+ this._start_polling();
+ setTimeout(() => this._connect_ws(), 3000);
+ };
+
+ ws.onerror = () => { /* onclose will fire */ };
+ }
+ catch (e)
+ {
+ this._start_polling();
+ }
+ }
+
+ _start_polling()
+ {
+ if (!this._poll_timer)
+ {
+ this._poll_timer = setInterval(() => this._fetch_all(), 2000);
+ }
+ }
+
+ _render_agents(data)
+ {
+ const workers = data.workers || [];
+
+ if (this._agents_table)
+ {
+ this._agents_table.clear();
+ }
+ else
+ {
+ this._agents_table = this._agents_host.add_widget(
+ Table,
+ ["hostname", "CPUs", "CPU usage", "memory", "queues", "pending", "running", "completed", "traffic", "last seen"],
+ Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1
+ );
+ }
+
+ if (workers.length === 0)
+ {
+ return;
+ }
+
+ let totalCpus = 0, totalWeightedCpu = 0;
+ let totalMemUsed = 0, totalMemTotal = 0;
+ let totalQueues = 0, totalPending = 0, totalRunning = 0, totalCompleted = 0;
+ let totalRecv = 0, totalSent = 0;
+
+ for (const w of workers)
+ {
+ const cpus = w.cpus || 0;
+ const cpuUsage = w.cpu_usage;
+ const memUsed = w.memory_used || 0;
+ const memTotal = w.memory_total || 0;
+ const queues = w.active_queues || 0;
+ const pending = w.actions_pending || 0;
+ const running = w.actions_running || 0;
+ const completed = w.actions_completed || 0;
+ const recv = w.bytes_received || 0;
+ const sent = w.bytes_sent || 0;
+
+ totalCpus += cpus;
+ if (cpus > 0 && typeof cpuUsage === "number")
+ {
+ totalWeightedCpu += cpuUsage * cpus;
+ }
+ totalMemUsed += memUsed;
+ totalMemTotal += memTotal;
+ totalQueues += queues;
+ totalPending += pending;
+ totalRunning += running;
+ totalCompleted += completed;
+ totalRecv += recv;
+ totalSent += sent;
+
+ const hostname = w.hostname || "";
+ const row = this._agents_table.add_row(
+ hostname,
+ cpus > 0 ? Friendly.sep(cpus) : "-",
+ typeof cpuUsage === "number" ? cpuUsage.toFixed(1) + "%" : "-",
+ memTotal > 0 ? Friendly.bytes(memUsed) + " / " + Friendly.bytes(memTotal) : "-",
+ queues > 0 ? Friendly.sep(queues) : "-",
+ Friendly.sep(pending),
+ Friendly.sep(running),
+ Friendly.sep(completed),
+ this._format_traffic(recv, sent),
+ this._format_last_seen(w.dt),
+ );
+
+ // Link hostname to worker dashboard
+ if (w.uri)
+ {
+ const cell = row.get_cell(0);
+ cell.inner().textContent = "";
+ cell.tag("a").text(hostname).attr("href", w.uri + "/dashboard/compute/").attr("target", "_blank");
+ }
+ }
+
+ // Total row
+ const total = this._agents_table.add_row(
+ "TOTAL",
+ Friendly.sep(totalCpus),
+ "",
+ totalMemTotal > 0 ? Friendly.bytes(totalMemUsed) + " / " + Friendly.bytes(totalMemTotal) : "-",
+ Friendly.sep(totalQueues),
+ Friendly.sep(totalPending),
+ Friendly.sep(totalRunning),
+ Friendly.sep(totalCompleted),
+ this._format_traffic(totalRecv, totalSent),
+ "",
+ );
+ total.get_cell(0).style("fontWeight", "bold");
+ }
+
+ _render_clients(clients)
+ {
+ if (this._clients_table)
+ {
+ this._clients_table.clear();
+ }
+ else
+ {
+ this._clients_table = this._clients_host.add_widget(
+ Table,
+ ["client ID", "hostname", "address", "last seen"],
+ Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable, -1
+ );
+ }
+
+ for (const c of clients)
+ {
+ this._clients_table.add_row(
+ c.id || "",
+ c.hostname || "",
+ c.address || "",
+ this._format_last_seen(c.dt),
+ );
+ }
+ }
+
+ _render_history(events)
+ {
+ if (this._history_table)
+ {
+ this._history_table.clear();
+ }
+ else
+ {
+ this._history_table = this._history_host.add_widget(
+ Table,
+ ["time", "event", "worker", "hostname"],
+ Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable, -1
+ );
+ }
+
+ for (const evt of events)
+ {
+ this._history_table.add_row(
+ this._format_timestamp(evt.ts),
+ evt.type || "",
+ evt.worker_id || "",
+ evt.hostname || "",
+ );
+ }
+ }
+
+ _render_client_history(events)
+ {
+ if (this._client_history_table)
+ {
+ this._client_history_table.clear();
+ }
+ else
+ {
+ this._client_history_table = this._client_history_host.add_widget(
+ Table,
+ ["time", "event", "client", "hostname"],
+ Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable, -1
+ );
+ }
+
+ for (const evt of events)
+ {
+ this._client_history_table.add_row(
+ this._format_timestamp(evt.ts),
+ evt.type || "",
+ evt.client_id || "",
+ evt.hostname || "",
+ );
+ }
+ }
+
+ _metric(parent, value, label, hero = false)
+ {
+ const m = parent.tag().classify("tile-metric");
+ if (hero)
+ {
+ m.classify("tile-metric-hero");
+ }
+ m.tag().classify("metric-value").text(value);
+ m.tag().classify("metric-label").text(label);
+ }
+
+ _format_last_seen(dtMs)
+ {
+ if (dtMs == null)
+ {
+ return "-";
+ }
+ const seconds = Math.floor(dtMs / 1000);
+ if (seconds < 60)
+ {
+ return seconds + "s ago";
+ }
+ const minutes = Math.floor(seconds / 60);
+ if (minutes < 60)
+ {
+ return minutes + "m " + (seconds % 60) + "s ago";
+ }
+ const hours = Math.floor(minutes / 60);
+ return hours + "h " + (minutes % 60) + "m ago";
+ }
+
+ _format_traffic(recv, sent)
+ {
+ if (!recv && !sent)
+ {
+ return "-";
+ }
+ return Friendly.bytes(recv) + " / " + Friendly.bytes(sent);
+ }
+
+ _format_timestamp(ts)
+ {
+ if (!ts)
+ {
+ return "-";
+ }
+ let date;
+ if (typeof ts === "number")
+ {
+ // .NET-style ticks: convert to Unix ms
+ const unixMs = (ts - 621355968000000000) / 10000;
+ date = new Date(unixMs);
+ }
+ else
+ {
+ date = new Date(ts);
+ }
+ if (isNaN(date.getTime()))
+ {
+ return "-";
+ }
+ return date.toLocaleTimeString();
+ }
+}
diff --git a/src/zenserver/frontend/html/pages/page.js b/src/zenserver/frontend/html/pages/page.js
index 9a9541904..dd8032c28 100644
--- a/src/zenserver/frontend/html/pages/page.js
+++ b/src/zenserver/frontend/html/pages/page.js
@@ -3,6 +3,7 @@
"use strict";
import { WidgetHost } from "../util/widgets.js"
+import { Fetcher } from "../util/fetcher.js"
////////////////////////////////////////////////////////////////////////////////
export class PageBase extends WidgetHost
@@ -63,31 +64,85 @@ export class ZenPage extends PageBase
super(parent, ...args);
super.set_title("zen");
this.add_branding(parent);
+ this.add_service_nav(parent);
this.generate_crumbs();
}
add_branding(parent)
{
- var root = parent.tag().id("branding");
-
- const zen_store = root.tag("pre").id("logo").text(
- "_________ _______ __\n" +
- "\\____ /___ ___ / ___// |__ ___ ______ ____\n" +
- " / __/ __ \\ / \\ \\___ \\\\_ __// \\\\_ \\/ __ \\\n" +
- " / \\ __// | \\/ \\| | ( - )| |\\/\\ __/\n" +
- "/______/\\___/\\__|__/\\______/|__| \\___/ |__| \\___|"
- );
- zen_store.tag().id("go_home").on_click(() => window.location.search = "");
-
- root.tag("img").attr("src", "favicon.ico").id("ue_logo");
-
- /*
- _________ _______ __
- \____ /___ ___ / ___// |__ ___ ______ ____
- / __/ __ \ / \ \___ \\_ __// \\_ \/ __ \
- / \ __// | \/ \| | ( - )| |\/\ __/
- /______/\___/\__|__/\______/|__| \___/ |__| \___|
- */
+ var banner = parent.tag("zen-banner");
+ banner.attr("subtitle", "SERVER");
+ banner.attr("tagline", "Local Storage Service");
+ banner.attr("logo-src", "favicon.ico");
+ banner.attr("load", "0");
+
+ this._banner = banner;
+ this._poll_status();
+ }
+
+ async _poll_status()
+ {
+ try
+ {
+ var cbo = await new Fetcher().resource("/status/status").cbo();
+ if (cbo)
+ {
+ var obj = cbo.as_object();
+
+ var hostname = obj.find("hostname");
+ if (hostname)
+ {
+ this._banner.attr("tagline", "Local Storage Service \u2014 " + hostname.as_value());
+ }
+
+ var cpu = obj.find("cpuUsagePercent");
+ if (cpu)
+ {
+ this._banner.attr("load", cpu.as_value().toFixed(1));
+ }
+ }
+ }
+ catch (e) { console.warn("status poll:", e); }
+
+ setTimeout(() => this._poll_status(), 2000);
+ }
+
+ add_service_nav(parent)
+ {
+ const nav = parent.tag().id("service_nav");
+
+ // Map service base URIs to dashboard links; this table is also used to determine
+ // which links to show based on the services that are currently registered.
+
+ const service_dashboards = [
+ { base_uri: "/compute/", label: "Compute", href: "/dashboard/?page=compute" },
+ { base_uri: "/orch/", label: "Orchestrator", href: "/dashboard/?page=orchestrator" },
+ { base_uri: "/hub/", label: "Hub", href: "/dashboard/?page=hub" },
+ ];
+
+ nav.tag("a").text("Home").attr("href", "/dashboard/");
+
+ nav.tag("a").text("Sessions").attr("href", "/dashboard/?page=sessions");
+ nav.tag("a").text("Cache").attr("href", "/dashboard/?page=cache");
+ nav.tag("a").text("Projects").attr("href", "/dashboard/?page=projects");
+ this._info_link = nav.tag("a").text("Info").attr("href", "/dashboard/?page=info");
+
+ new Fetcher().resource("/api/").json().then((data) => {
+ const services = data.services || [];
+ const uris = new Set(services.map(s => s.base_uri));
+
+ const links = service_dashboards.filter(d => uris.has(d.base_uri));
+
+ // Insert service links before the Info link
+ const info_elem = this._info_link.inner();
+ for (const link of links)
+ {
+ const a = document.createElement("a");
+ a.textContent = link.label;
+ a.href = link.href;
+ info_elem.parentNode.insertBefore(a, info_elem);
+ }
+ }).catch(() => {});
}
set_title(...args)
@@ -97,7 +152,7 @@ export class ZenPage extends PageBase
generate_crumbs()
{
- const auto_name = this.get_param("page") || "start";
+ var auto_name = this.get_param("page") || "start";
if (auto_name == "start")
return;
@@ -114,15 +169,30 @@ export class ZenPage extends PageBase
var project = this.get_param("project");
if (project != undefined)
{
+ auto_name = project;
var oplog = this.get_param("oplog");
if (oplog != undefined)
{
- new_crumb("project", `?page=project&project=${project}`);
- if (this.get_param("opkey"))
- new_crumb("oplog", `?page=oplog&project=${project}&oplog=${oplog}`);
+ new_crumb(auto_name, `?page=project&project=${project}`);
+ auto_name = oplog;
+ var opkey = this.get_param("opkey")
+ if (opkey != undefined)
+ {
+ new_crumb(auto_name, `?page=oplog&project=${project}&oplog=${oplog}`);
+ auto_name = opkey.split("/").pop().split("\\").pop();
+
+ // Check if we're viewing cook artifacts
+ var page = this.get_param("page");
+ var hash = this.get_param("hash");
+ if (hash != undefined && page == "cookartifacts")
+ {
+ new_crumb(auto_name, `?page=entry&project=${project}&oplog=${oplog}&opkey=${opkey}`);
+ auto_name = "cook artifacts";
+ }
+ }
}
}
- new_crumb(auto_name.toLowerCase());
+ new_crumb(auto_name);
}
}
diff --git a/src/zenserver/frontend/html/pages/project.js b/src/zenserver/frontend/html/pages/project.js
index 42ae30c8c..3a7a45527 100644
--- a/src/zenserver/frontend/html/pages/project.js
+++ b/src/zenserver/frontend/html/pages/project.js
@@ -59,7 +59,7 @@ export class Page extends ZenPage
info = await info;
row.get_cell(1).text(info["markerpath"]);
- row.get_cell(2).text(Friendly.kib(info["totalsize"]));
+ row.get_cell(2).text(Friendly.bytes(info["totalsize"]));
row.get_cell(3).text(Friendly.sep(info["opcount"]));
row.get_cell(4).text(info["expired"]);
}
diff --git a/src/zenserver/frontend/html/pages/projects.js b/src/zenserver/frontend/html/pages/projects.js
new file mode 100644
index 000000000..9c1e519d4
--- /dev/null
+++ b/src/zenserver/frontend/html/pages/projects.js
@@ -0,0 +1,447 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+"use strict";
+
+import { ZenPage } from "./page.js"
+import { Fetcher } from "../util/fetcher.js"
+import { Friendly } from "../util/friendly.js"
+import { Modal } from "../util/modal.js"
+import { Table, Toolbar } from "../util/widgets.js"
+
+////////////////////////////////////////////////////////////////////////////////
+export class Page extends ZenPage
+{
+ async main()
+ {
+ this.set_title("projects");
+
+ // Project Service Stats
+ const stats_section = this._collapsible_section("Project Service Stats");
+ stats_section.tag().classify("dropall").text("raw yaml \u2192").on_click(() => {
+ window.open("/stats/prj.yaml", "_blank");
+ });
+ this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles");
+
+ const stats = await new Fetcher().resource("stats", "prj").json();
+ if (stats)
+ {
+ this._render_stats(stats);
+ }
+
+ this._connect_stats_ws();
+
+ // Projects list
+ var section = this._collapsible_section("Projects");
+
+ section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all());
+
+ var columns = [
+ "name",
+ "project dir",
+ "engine dir",
+ "oplogs",
+ "actions",
+ ];
+
+ this._project_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric);
+
+ var projects = await new Fetcher().resource("/prj/list").json();
+ projects.sort((a, b) => (b.LastAccessTime || 0) - (a.LastAccessTime || 0));
+
+ for (const project of projects)
+ {
+ var row = this._project_table.add_row(
+ "",
+ "",
+ "",
+ "",
+ );
+
+ var cell = row.get_cell(0);
+ cell.tag().text(project.Id).on_click(() => this.view_project(project.Id));
+
+ if (project.ProjectRootDir)
+ {
+ row.get_cell(1).tag("a").text(project.ProjectRootDir)
+ .attr("href", "vscode://" + project.ProjectRootDir.replace(/\\/g, "/"));
+ }
+ if (project.EngineRootDir)
+ {
+ row.get_cell(2).tag("a").text(project.EngineRootDir)
+ .attr("href", "vscode://" + project.EngineRootDir.replace(/\\/g, "/"));
+ }
+
+ cell = row.get_cell(-1);
+ const action_tb = new Toolbar(cell, true).left();
+ action_tb.add("view").on_click(() => this.view_project(project.Id));
+ action_tb.add("drop").on_click(() => this.drop_project(project.Id));
+
+ row.attr("zs_name", project.Id);
+
+ // Fetch project details to get oplog count
+ new Fetcher().resource("prj", project.Id).json().then((info) => {
+ const oplogs = info["oplogs"] || [];
+ row.get_cell(3).text(Friendly.sep(oplogs.length)).style("textAlign", "right");
+ // Right-align the corresponding header cell
+ const header = this._project_table._element.firstElementChild;
+ if (header && header.children[4])
+ {
+ header.children[4].style.textAlign = "right";
+ }
+ }).catch(() => {});
+ }
+
+ // Project detail area (inside projects section so it collapses together)
+ this._project_host = section;
+ this._project_container = null;
+ this._selected_project = null;
+
+ // Restore project from URL if present
+ const prj_param = this.get_param("project");
+ if (prj_param)
+ {
+ this.view_project(prj_param);
+ }
+ }
+
+ _collapsible_section(name)
+ {
+ const section = this.add_section(name);
+ const container = section._parent.inner();
+ const heading = container.firstElementChild;
+
+ heading.style.cursor = "pointer";
+ heading.style.userSelect = "none";
+
+ const indicator = document.createElement("span");
+ indicator.textContent = " \u25BC";
+ indicator.style.fontSize = "0.7em";
+ heading.appendChild(indicator);
+
+ let collapsed = false;
+ heading.addEventListener("click", (e) => {
+ if (e.target !== heading && e.target !== indicator)
+ {
+ return;
+ }
+ collapsed = !collapsed;
+ indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
+ let sibling = heading.nextElementSibling;
+ while (sibling)
+ {
+ sibling.style.display = collapsed ? "none" : "";
+ sibling = sibling.nextElementSibling;
+ }
+ });
+
+ return section;
+ }
+
+ _clear_param(name)
+ {
+ this._params.delete(name);
+ const url = new URL(window.location);
+ url.searchParams.delete(name);
+ history.replaceState(null, "", url);
+ }
+
+ _connect_stats_ws()
+ {
+ try
+ {
+ const proto = location.protocol === "https:" ? "wss:" : "ws:";
+ const ws = new WebSocket(`${proto}//${location.host}/stats`);
+
+ try { this._ws_paused = localStorage.getItem("zen-ws-paused") === "true"; } catch (e) { this._ws_paused = false; }
+ document.addEventListener("zen-ws-toggle", (e) => {
+ this._ws_paused = e.detail.paused;
+ });
+
+ ws.onmessage = (ev) => {
+ if (this._ws_paused)
+ {
+ return;
+ }
+ try
+ {
+ const all_stats = JSON.parse(ev.data);
+ const stats = all_stats["prj"];
+ if (stats)
+ {
+ this._render_stats(stats);
+ }
+ }
+ catch (e) { /* ignore parse errors */ }
+ };
+
+ ws.onclose = () => { this._stats_ws = null; };
+ ws.onerror = () => { ws.close(); };
+
+ this._stats_ws = ws;
+ }
+ catch (e) { /* WebSocket not available */ }
+ }
+
+ _render_stats(stats)
+ {
+ const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj);
+ const grid = this._stats_grid;
+
+ grid.inner().innerHTML = "";
+
+ // HTTP Requests tile
+ {
+ const req = safe(stats, "requests");
+ if (req)
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("HTTP Requests");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ const reqData = req.requests || req;
+ this._metric(left, Friendly.sep(safe(stats, "store.requestcount") || 0), "total requests", true);
+ if (reqData.rate_mean > 0)
+ {
+ this._metric(left, Friendly.sep(reqData.rate_mean, 1) + "/s", "req/sec (mean)");
+ }
+ if (reqData.rate_1 > 0)
+ {
+ this._metric(left, Friendly.sep(reqData.rate_1, 1) + "/s", "req/sec (1m)");
+ }
+ const badRequests = safe(stats, "store.badrequestcount") || 0;
+ this._metric(left, Friendly.sep(badRequests), "bad requests");
+
+ const right = columns.tag().classify("tile-metrics");
+ this._metric(right, Friendly.duration(reqData.t_avg || 0), "avg latency", true);
+ if (reqData.t_p75)
+ {
+ this._metric(right, Friendly.duration(reqData.t_p75), "p75");
+ }
+ if (reqData.t_p95)
+ {
+ this._metric(right, Friendly.duration(reqData.t_p95), "p95");
+ }
+ if (reqData.t_p99)
+ {
+ this._metric(right, Friendly.duration(reqData.t_p99), "p99");
+ }
+ }
+ }
+
+ // Store Operations tile
+ {
+ const store = safe(stats, "store");
+ if (store)
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Store Operations");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ const proj = store.project || {};
+ this._metric(left, Friendly.sep(proj.readcount || 0), "project reads", true);
+ this._metric(left, Friendly.sep(proj.writecount || 0), "project writes");
+ this._metric(left, Friendly.sep(proj.deletecount || 0), "project deletes");
+
+ const right = columns.tag().classify("tile-metrics");
+ const oplog = store.oplog || {};
+ this._metric(right, Friendly.sep(oplog.readcount || 0), "oplog reads", true);
+ this._metric(right, Friendly.sep(oplog.writecount || 0), "oplog writes");
+ this._metric(right, Friendly.sep(oplog.deletecount || 0), "oplog deletes");
+ }
+ }
+
+ // Op & Chunk tile
+ {
+ const store = safe(stats, "store");
+ if (store)
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Ops & Chunks");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ const op = store.op || {};
+ const opTotal = (op.hitcount || 0) + (op.misscount || 0);
+ const opRatio = opTotal > 0 ? (((op.hitcount || 0) / opTotal) * 100).toFixed(1) + "%" : "-";
+ this._metric(left, opRatio, "op hit ratio", true);
+ this._metric(left, Friendly.sep(op.hitcount || 0), "op hits");
+ this._metric(left, Friendly.sep(op.misscount || 0), "op misses");
+ this._metric(left, Friendly.sep(op.writecount || 0), "op writes");
+
+ const right = columns.tag().classify("tile-metrics");
+ const chunk = store.chunk || {};
+ const chunkTotal = (chunk.hitcount || 0) + (chunk.misscount || 0);
+ const chunkRatio = chunkTotal > 0 ? (((chunk.hitcount || 0) / chunkTotal) * 100).toFixed(1) + "%" : "-";
+ this._metric(right, chunkRatio, "chunk hit ratio", true);
+ this._metric(right, Friendly.sep(chunk.hitcount || 0), "chunk hits");
+ this._metric(right, Friendly.sep(chunk.misscount || 0), "chunk misses");
+ this._metric(right, Friendly.sep(chunk.writecount || 0), "chunk writes");
+ }
+ }
+
+ // Storage tile
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Storage");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ this._metric(left, safe(stats, "store.size.disk") != null ? Friendly.bytes(safe(stats, "store.size.disk")) : "-", "store disk", true);
+ this._metric(left, safe(stats, "store.size.memory") != null ? Friendly.bytes(safe(stats, "store.size.memory")) : "-", "store memory");
+
+ const right = columns.tag().classify("tile-metrics");
+ this._metric(right, safe(stats, "cid.size.total") != null ? Friendly.bytes(safe(stats, "cid.size.total")) : "-", "cid total", true);
+ this._metric(right, safe(stats, "cid.size.tiny") != null ? Friendly.bytes(safe(stats, "cid.size.tiny")) : "-", "cid tiny");
+ this._metric(right, safe(stats, "cid.size.small") != null ? Friendly.bytes(safe(stats, "cid.size.small")) : "-", "cid small");
+ this._metric(right, safe(stats, "cid.size.large") != null ? Friendly.bytes(safe(stats, "cid.size.large")) : "-", "cid large");
+ }
+ }
+
+ _metric(parent, value, label, hero = false)
+ {
+ const m = parent.tag().classify("tile-metric");
+ if (hero)
+ {
+ m.classify("tile-metric-hero");
+ }
+ m.tag().classify("metric-value").text(value);
+ m.tag().classify("metric-label").text(label);
+ }
+
+ async view_project(project_id)
+ {
+ // Toggle off if already selected
+ if (this._selected_project === project_id)
+ {
+ this._selected_project = null;
+ this._clear_project_detail();
+ this._clear_param("project");
+ return;
+ }
+
+ this._selected_project = project_id;
+ this._clear_project_detail();
+ this.set_param("project", project_id);
+
+ const info = await new Fetcher().resource("prj", project_id).json();
+ if (this._selected_project !== project_id)
+ {
+ return;
+ }
+
+ const section = this._project_host.add_section(project_id);
+ this._project_container = section;
+
+ // Oplogs table
+ const oplog_section = section.add_section("Oplogs");
+ const oplog_table = oplog_section.add_widget(
+ Table,
+ ["name", "marker", "size", "ops", "expired", "actions"],
+ Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric
+ );
+
+ let totalSize = 0, totalOps = 0;
+ const total_row = oplog_table.add_row("TOTAL");
+ total_row.get_cell(0).style("fontWeight", "bold");
+ total_row.get_cell(2).style("textAlign", "right").style("fontWeight", "bold");
+ total_row.get_cell(3).style("textAlign", "right").style("fontWeight", "bold");
+
+ // Right-align header for numeric columns (size, ops)
+ const header = oplog_table._element.firstElementChild;
+ for (let i = 3; i < header.children.length - 1; i++)
+ {
+ header.children[i].style.textAlign = "right";
+ }
+
+ for (const oplog of info["oplogs"] || [])
+ {
+ const name = oplog["id"];
+ const row = oplog_table.add_row("");
+
+ var cell = row.get_cell(0);
+ cell.tag().text(name).link("", {
+ "page": "oplog",
+ "project": project_id,
+ "oplog": name,
+ });
+
+ cell = row.get_cell(-1);
+ const action_tb = new Toolbar(cell, true).left();
+ action_tb.add("list").link("", { "page": "oplog", "project": project_id, "oplog": name });
+ action_tb.add("tree").link("", { "page": "tree", "project": project_id, "oplog": name });
+ action_tb.add("drop").on_click(() => this.drop_oplog(project_id, name));
+
+ new Fetcher().resource("prj", project_id, "oplog", name).json().then((data) => {
+ row.get_cell(1).text(data["markerpath"]);
+ row.get_cell(2).text(Friendly.bytes(data["totalsize"])).style("textAlign", "right");
+ row.get_cell(3).text(Friendly.sep(data["opcount"])).style("textAlign", "right");
+ row.get_cell(4).text(data["expired"]);
+
+ totalSize += data["totalsize"] || 0;
+ totalOps += data["opcount"] || 0;
+ total_row.get_cell(2).text(Friendly.bytes(totalSize)).style("textAlign", "right").style("fontWeight", "bold");
+ total_row.get_cell(3).text(Friendly.sep(totalOps)).style("textAlign", "right").style("fontWeight", "bold");
+ }).catch(() => {});
+ }
+ }
+
+ _clear_project_detail()
+ {
+ if (this._project_container)
+ {
+ this._project_container._parent.inner().remove();
+ this._project_container = null;
+ }
+ }
+
+ drop_oplog(project_id, oplog_id)
+ {
+ const drop = async () => {
+ await new Fetcher().resource("prj", project_id, "oplog", oplog_id).delete();
+ // Refresh the project view
+ this._selected_project = null;
+ this._clear_project_detail();
+ this.view_project(project_id);
+ };
+
+ new Modal()
+ .title("Confirmation")
+ .message(`Drop oplog '${oplog_id}'?`)
+ .option("Yes", () => drop())
+ .option("No");
+ }
+
+ drop_project(project_id)
+ {
+ const drop = async () => {
+ await new Fetcher().resource("prj", project_id).delete();
+ this.reload();
+ };
+
+ new Modal()
+ .title("Confirmation")
+ .message(`Drop project '${project_id}'?`)
+ .option("Yes", () => drop())
+ .option("No");
+ }
+
+ async drop_all()
+ {
+ const drop = async () => {
+ for (const row of this._project_table)
+ {
+ const project_id = row.attr("zs_name");
+ await new Fetcher().resource("prj", project_id).delete();
+ }
+ this.reload();
+ };
+
+ new Modal()
+ .title("Confirmation")
+ .message("Drop every project?")
+ .option("Yes", () => drop())
+ .option("No");
+ }
+}
diff --git a/src/zenserver/frontend/html/pages/sessions.js b/src/zenserver/frontend/html/pages/sessions.js
new file mode 100644
index 000000000..95533aa96
--- /dev/null
+++ b/src/zenserver/frontend/html/pages/sessions.js
@@ -0,0 +1,61 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+"use strict";
+
+import { ZenPage } from "./page.js"
+import { Fetcher } from "../util/fetcher.js"
+import { Table } from "../util/widgets.js"
+
+////////////////////////////////////////////////////////////////////////////////
+export class Page extends ZenPage
+{
+ async main()
+ {
+ this.set_title("sessions");
+
+ const data = await new Fetcher().resource("/sessions/").json();
+ const sessions = data.sessions || [];
+
+ const section = this.add_section("Sessions");
+
+ if (sessions.length === 0)
+ {
+ section.tag().classify("empty-state").text("No active sessions.");
+ return;
+ }
+
+ const columns = [
+ "id",
+ "created",
+ "updated",
+ "metadata",
+ ];
+ const table = section.add_widget(Table, columns, Table.Flag_FitLeft);
+
+ for (const session of sessions)
+ {
+ const created = session.created_at ? new Date(session.created_at).toLocaleString() : "-";
+ const updated = session.updated_at ? new Date(session.updated_at).toLocaleString() : "-";
+ const meta = this._format_metadata(session.metadata);
+
+ const row = table.add_row(
+ session.id || "-",
+ created,
+ updated,
+ meta,
+ );
+ }
+ }
+
+ _format_metadata(metadata)
+ {
+ if (!metadata || Object.keys(metadata).length === 0)
+ {
+ return "-";
+ }
+
+ return Object.entries(metadata)
+ .map(([k, v]) => `${k}: ${v}`)
+ .join(", ");
+ }
+}
diff --git a/src/zenserver/frontend/html/pages/start.js b/src/zenserver/frontend/html/pages/start.js
index 4c8789431..3a68a725d 100644
--- a/src/zenserver/frontend/html/pages/start.js
+++ b/src/zenserver/frontend/html/pages/start.js
@@ -13,109 +13,117 @@ export class Page extends ZenPage
{
async main()
{
+ // Discover which services are available
+ const api_data = await new Fetcher().resource("/api/").json();
+ const available = new Set((api_data.services || []).map(s => s.base_uri));
+
// project list
- var section = this.add_section("projects");
+ var project_table = null;
+ if (available.has("/prj/"))
+ {
+ var section = this.add_section("Cooked Projects");
- section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("projects"));
+ section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("projects"));
- var columns = [
- "name",
- "project_dir",
- "engine_dir",
- "actions",
- ];
- var project_table = section.add_widget(Table, columns);
+ var columns = [
+ "name",
+ "project_dir",
+ "engine_dir",
+ "actions",
+ ];
+ project_table = section.add_widget(Table, columns);
- for (const project of await new Fetcher().resource("/prj/list").json())
- {
- var row = project_table.add_row(
- "",
- project.ProjectRootDir,
- project.EngineRootDir,
- );
+ var projects = await new Fetcher().resource("/prj/list").json();
+ projects.sort((a, b) => (b.LastAccessTime || 0) - (a.LastAccessTime || 0));
+ projects = projects.slice(0, 25);
+ projects.sort((a, b) => a.Id.localeCompare(b.Id));
- var cell = row.get_cell(0);
- cell.tag().text(project.Id).on_click((x) => this.view_project(x), project.Id);
+ for (const project of projects)
+ {
+ var row = project_table.add_row(
+ "",
+ project.ProjectRootDir,
+ project.EngineRootDir,
+ );
+
+ var cell = row.get_cell(0);
+ cell.tag().text(project.Id).on_click((x) => this.view_project(x), project.Id);
- var cell = row.get_cell(-1);
- var action_tb = new Toolbar(cell, true);
- action_tb.left().add("view").on_click((x) => this.view_project(x), project.Id);
- action_tb.left().add("drop").on_click((x) => this.drop_project(x), project.Id);
+ var cell = row.get_cell(-1);
+ var action_tb = new Toolbar(cell, true);
+ action_tb.left().add("view").on_click((x) => this.view_project(x), project.Id);
+ action_tb.left().add("drop").on_click((x) => this.drop_project(x), project.Id);
- row.attr("zs_name", project.Id);
+ row.attr("zs_name", project.Id);
+ }
}
// cache
- var section = this.add_section("z$");
-
- section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("z$"));
-
- columns = [
- "namespace",
- "dir",
- "buckets",
- "entries",
- "size disk",
- "size mem",
- "actions",
- ]
- var zcache_info = new Fetcher().resource("/z$/").json();
- const cache_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight);
- for (const namespace of (await zcache_info)["Namespaces"])
+ var cache_table = null;
+ if (available.has("/z$/"))
{
- new Fetcher().resource(`/z$/${namespace}/`).json().then((data) => {
- const row = cache_table.add_row(
- "",
- data["Configuration"]["RootDir"],
- data["Buckets"].length,
- data["EntryCount"],
- Friendly.kib(data["StorageSize"].DiskSize),
- Friendly.kib(data["StorageSize"].MemorySize)
- );
- var cell = row.get_cell(0);
- cell.tag().text(namespace).on_click(() => this.view_zcache(namespace));
- row.get_cell(1).tag().text(namespace);
+ var section = this.add_section("Cache");
- cell = row.get_cell(-1);
- const action_tb = new Toolbar(cell, true);
- action_tb.left().add("view").on_click(() => this.view_zcache(namespace));
- action_tb.left().add("drop").on_click(() => this.drop_zcache(namespace));
+ section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("z$"));
- row.attr("zs_name", namespace);
- });
+ var columns = [
+ "namespace",
+ "dir",
+ "buckets",
+ "entries",
+ "size disk",
+ "size mem",
+ "actions",
+ ];
+ var zcache_info = await new Fetcher().resource("/z$/").json();
+ cache_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight);
+ for (const namespace of zcache_info["Namespaces"] || [])
+ {
+ new Fetcher().resource(`/z$/${namespace}/`).json().then((data) => {
+ const row = cache_table.add_row(
+ "",
+ data["Configuration"]["RootDir"],
+ data["Buckets"].length,
+ data["EntryCount"],
+ Friendly.bytes(data["StorageSize"].DiskSize),
+ Friendly.bytes(data["StorageSize"].MemorySize)
+ );
+ var cell = row.get_cell(0);
+ cell.tag().text(namespace).on_click(() => this.view_zcache(namespace));
+ row.get_cell(1).tag().text(namespace);
+
+ cell = row.get_cell(-1);
+ const action_tb = new Toolbar(cell, true);
+ action_tb.left().add("view").on_click(() => this.view_zcache(namespace));
+ action_tb.left().add("drop").on_click(() => this.drop_zcache(namespace));
+
+ row.attr("zs_name", namespace);
+ });
+ }
}
- // stats
+ // stats tiles
const safe_lookup = (obj, path, pretty=undefined) => {
const ret = path.split(".").reduce((a,b) => a && a[b], obj);
- if (ret === undefined) return "-";
+ if (ret === undefined) return undefined;
return pretty ? pretty(ret) : ret;
};
- section = this.add_section("stats");
- columns = [
- "name",
- "req count",
- "size disk",
- "size mem",
- "cid total",
- ];
- const stats_table = section.add_widget(Table, columns, Table.Flag_PackRight);
- var providers = new Fetcher().resource("stats").json();
- for (var provider of (await providers)["providers"])
- {
- var stats = await new Fetcher().resource("stats", provider).json();
- var size_stat = (stats.store || stats.cache);
- var values = [
- "",
- safe_lookup(stats, "requests.count"),
- safe_lookup(size_stat, "size.disk", Friendly.kib),
- safe_lookup(size_stat, "size.memory", Friendly.kib),
- safe_lookup(stats, "cid.size.total"),
- ];
- row = stats_table.add_row(...values);
- row.get_cell(0).tag().text(provider).on_click((x) => this.view_stat(x), provider);
- }
+ var section = this.add_section("Stats");
+ section.tag().classify("dropall").text("metrics dashboard →").on_click(() => {
+ window.location = "?page=metrics";
+ });
+
+ var providers_data = await new Fetcher().resource("stats").json();
+ var provider_list = providers_data["providers"] || [];
+ var all_stats = {};
+ await Promise.all(provider_list.map(async (provider) => {
+ all_stats[provider] = await new Fetcher().resource("stats", provider).json();
+ }));
+
+ this._stats_grid = section.tag().classify("grid").classify("stats-tiles");
+ this._safe_lookup = safe_lookup;
+ this._render_stats(all_stats);
// version
var ver_tag = this.tag().id("version");
@@ -125,6 +133,159 @@ export class Page extends ZenPage
this._project_table = project_table;
this._cache_table = cache_table;
+
+ // WebSocket for live stats updates
+ this._connect_stats_ws();
+ }
+
+ _connect_stats_ws()
+ {
+ try
+ {
+ const proto = location.protocol === "https:" ? "wss:" : "ws:";
+ const ws = new WebSocket(`${proto}//${location.host}/stats`);
+
+ try { this._ws_paused = localStorage.getItem("zen-ws-paused") === "true"; } catch (e) { this._ws_paused = false; }
+ document.addEventListener("zen-ws-toggle", (e) => {
+ this._ws_paused = e.detail.paused;
+ });
+
+ ws.onmessage = (ev) => {
+ if (this._ws_paused)
+ {
+ return;
+ }
+ try
+ {
+ const all_stats = JSON.parse(ev.data);
+ this._render_stats(all_stats);
+ }
+ catch (e) { /* ignore parse errors */ }
+ };
+
+ ws.onclose = () => { this._stats_ws = null; };
+ ws.onerror = () => { ws.close(); };
+
+ this._stats_ws = ws;
+ }
+ catch (e) { /* WebSocket not available */ }
+ }
+
+ _render_stats(all_stats)
+ {
+ const grid = this._stats_grid;
+ const safe_lookup = this._safe_lookup;
+
+ // Clear existing tiles
+ grid.inner().innerHTML = "";
+
+ // HTTP tile — aggregate request stats across all providers
+ {
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("HTTP");
+ const columns = tile.tag().classify("tile-columns");
+
+ // Left column: request stats
+ const left = columns.tag().classify("tile-metrics");
+
+ let total_requests = 0;
+ let total_rate = 0;
+ for (const p in all_stats)
+ {
+ total_requests += (safe_lookup(all_stats[p], "requests.count") || 0);
+ total_rate += (safe_lookup(all_stats[p], "requests.rate_1") || 0);
+ }
+
+ this._add_tile_metric(left, Friendly.sep(total_requests), "total requests", true);
+ if (total_rate > 0)
+ this._add_tile_metric(left, Friendly.sep(total_rate, 1) + "/s", "req/sec (1m)");
+
+ // Right column: websocket stats
+ const ws = all_stats["http"] ? (all_stats["http"]["websockets"] || {}) : {};
+ const right = columns.tag().classify("tile-metrics");
+
+ this._add_tile_metric(right, Friendly.sep(ws.active_connections || 0), "ws connections", true);
+ const ws_frames = (ws.frames_received || 0) + (ws.frames_sent || 0);
+ if (ws_frames > 0)
+ this._add_tile_metric(right, Friendly.sep(ws_frames), "ws frames");
+ const ws_bytes = (ws.bytes_received || 0) + (ws.bytes_sent || 0);
+ if (ws_bytes > 0)
+ this._add_tile_metric(right, Friendly.bytes(ws_bytes), "ws traffic");
+
+ tile.on_click(() => { window.location = "?page=metrics"; });
+ }
+
+ // Cache tile (z$)
+ if (all_stats["z$"])
+ {
+ const s = all_stats["z$"];
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Cache");
+ const body = tile.tag().classify("tile-metrics");
+
+ const hits = safe_lookup(s, "cache.hits") || 0;
+ const misses = safe_lookup(s, "cache.misses") || 0;
+ const ratio = (hits + misses) > 0 ? ((hits / (hits + misses)) * 100).toFixed(1) + "%" : "-";
+
+ this._add_tile_metric(body, ratio, "hit ratio", true);
+ this._add_tile_metric(body, safe_lookup(s, "cache.size.disk", Friendly.bytes) || "-", "disk");
+ this._add_tile_metric(body, safe_lookup(s, "cache.size.memory", Friendly.bytes) || "-", "memory");
+
+ tile.on_click(() => { window.location = "?page=stat&provider=z$"; });
+ }
+
+ // Project Store tile (prj)
+ if (all_stats["prj"])
+ {
+ const s = all_stats["prj"];
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Project Store");
+ const body = tile.tag().classify("tile-metrics");
+
+ this._add_tile_metric(body, safe_lookup(s, "requests.count", Friendly.sep) || "-", "requests", true);
+ this._add_tile_metric(body, safe_lookup(s, "store.size.disk", Friendly.bytes) || "-", "disk");
+
+ tile.on_click(() => { window.location = "?page=stat&provider=prj"; });
+ }
+
+ // Build Store tile (builds)
+ if (all_stats["builds"])
+ {
+ const s = all_stats["builds"];
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Build Store");
+ const body = tile.tag().classify("tile-metrics");
+
+ this._add_tile_metric(body, safe_lookup(s, "requests.count", Friendly.sep) || "-", "requests", true);
+ this._add_tile_metric(body, safe_lookup(s, "store.size.disk", Friendly.bytes) || "-", "disk");
+
+ tile.on_click(() => { window.location = "?page=stat&provider=builds"; });
+ }
+
+ // Workspace tile (ws)
+ if (all_stats["ws"])
+ {
+ const s = all_stats["ws"];
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Workspace");
+ const body = tile.tag().classify("tile-metrics");
+
+ this._add_tile_metric(body, safe_lookup(s, "requests.count", Friendly.sep) || "-", "requests", true);
+ this._add_tile_metric(body, safe_lookup(s, "workspaces.filescount", Friendly.sep) || "-", "files");
+
+ tile.on_click(() => { window.location = "?page=stat&provider=ws"; });
+ }
+ }
+
+ _add_tile_metric(parent, value, label, hero=false)
+ {
+ const m = parent.tag().classify("tile-metric");
+ if (hero)
+ {
+ m.classify("tile-metric-hero");
+ }
+ m.tag().classify("metric-value").text(value);
+ m.tag().classify("metric-label").text(label);
}
view_stat(provider)
diff --git a/src/zenserver/frontend/html/pages/stat.js b/src/zenserver/frontend/html/pages/stat.js
index d6c7fa8e8..4f020ac5e 100644
--- a/src/zenserver/frontend/html/pages/stat.js
+++ b/src/zenserver/frontend/html/pages/stat.js
@@ -33,7 +33,7 @@ class TemporalStat
out[key] = data[key];
}
- var friendly = this._as_bytes ? Friendly.kib : Friendly.sep;
+ var friendly = this._as_bytes ? Friendly.bytes : Friendly.sep;
var content = "";
for (var i = 0; i < columns.length; ++i)
diff --git a/src/zenserver/frontend/html/pages/tree.js b/src/zenserver/frontend/html/pages/tree.js
index 08a578492..b5fece5a3 100644
--- a/src/zenserver/frontend/html/pages/tree.js
+++ b/src/zenserver/frontend/html/pages/tree.js
@@ -106,7 +106,7 @@ export class Page extends ZenPage
for (var i = 0; i < 2; ++i)
{
- const size = Friendly.kib(new_nodes[name][i]);
+ const size = Friendly.bytes(new_nodes[name][i]);
info.tag().text(size);
}
diff --git a/src/zenserver/frontend/html/pages/zcache.js b/src/zenserver/frontend/html/pages/zcache.js
index 974893b21..d8bdc892a 100644
--- a/src/zenserver/frontend/html/pages/zcache.js
+++ b/src/zenserver/frontend/html/pages/zcache.js
@@ -27,8 +27,8 @@ export class Page extends ZenPage
cfg_table.add_object(info["Configuration"], true);
- storage_table.add_property("disk", Friendly.kib(info["StorageSize"]["DiskSize"]));
- storage_table.add_property("mem", Friendly.kib(info["StorageSize"]["MemorySize"]));
+ storage_table.add_property("disk", Friendly.bytes(info["StorageSize"]["DiskSize"]));
+ storage_table.add_property("mem", Friendly.bytes(info["StorageSize"]["MemorySize"]));
storage_table.add_property("entries", Friendly.sep(info["EntryCount"]));
var column_names = ["name", "disk", "mem", "entries", "actions"];
@@ -41,8 +41,8 @@ export class Page extends ZenPage
{
const row = bucket_table.add_row(bucket);
new Fetcher().resource(`/z$/${namespace}/${bucket}`).json().then((data) => {
- row.get_cell(1).text(Friendly.kib(data["StorageSize"]["DiskSize"]));
- row.get_cell(2).text(Friendly.kib(data["StorageSize"]["MemorySize"]));
+ row.get_cell(1).text(Friendly.bytes(data["StorageSize"]["DiskSize"]));
+ row.get_cell(2).text(Friendly.bytes(data["StorageSize"]["MemorySize"]));
row.get_cell(3).text(Friendly.sep(data["DiskEntryCount"]));
const cell = row.get_cell(-1);
diff --git a/src/zenserver/frontend/html/theme.js b/src/zenserver/frontend/html/theme.js
new file mode 100644
index 000000000..52ca116ab
--- /dev/null
+++ b/src/zenserver/frontend/html/theme.js
@@ -0,0 +1,116 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+// Theme toggle: cycles system → light → dark → system.
+// Persists choice in localStorage. Applies data-theme attribute on <html>.
+
+(function() {
+ var KEY = 'zen-theme';
+
+ function getStored() {
+ try { return localStorage.getItem(KEY); } catch (e) { return null; }
+ }
+
+ function setStored(value) {
+ try {
+ if (value) localStorage.setItem(KEY, value);
+ else localStorage.removeItem(KEY);
+ } catch (e) {}
+ }
+
+ function apply(theme) {
+ if (theme)
+ document.documentElement.setAttribute('data-theme', theme);
+ else
+ document.documentElement.removeAttribute('data-theme');
+ }
+
+ function getEffective(stored) {
+ if (stored) return stored;
+ return window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';
+ }
+
+ // Apply stored preference immediately (before paint)
+ var stored = getStored();
+ apply(stored);
+
+ // Create toggle button once DOM is ready
+ function createToggle() {
+ var btn = document.createElement('button');
+ btn.id = 'zen_theme_toggle';
+ btn.title = 'Toggle theme';
+
+ function updateIcon() {
+ var effective = getEffective(getStored());
+ // Show sun in dark mode (click to go light), moon in light mode (click to go dark)
+ btn.textContent = effective === 'dark' ? '\u2600' : '\u263E';
+
+ var isManual = getStored() != null;
+ btn.title = isManual
+ ? 'Theme: ' + effective + ' (click to change, double-click for system)'
+ : 'Theme: system (click to change)';
+ }
+
+ btn.addEventListener('click', function() {
+ var current = getStored();
+ var effective = getEffective(current);
+ // Toggle to the opposite
+ var next = effective === 'dark' ? 'light' : 'dark';
+ setStored(next);
+ apply(next);
+ updateIcon();
+ });
+
+ btn.addEventListener('dblclick', function(e) {
+ e.preventDefault();
+ // Reset to system preference
+ setStored(null);
+ apply(null);
+ updateIcon();
+ });
+
+ // Update icon when system preference changes
+ window.matchMedia('(prefers-color-scheme: dark)').addEventListener('change', function() {
+ if (!getStored()) updateIcon();
+ });
+
+ updateIcon();
+ document.body.appendChild(btn);
+
+ // WebSocket pause/play toggle
+ var WS_KEY = 'zen-ws-paused';
+ var wsBtn = document.createElement('button');
+ wsBtn.id = 'zen_ws_toggle';
+
+ var initialPaused = false;
+ try { initialPaused = localStorage.getItem(WS_KEY) === 'true'; } catch (e) {}
+
+ function updateWsIcon(paused) {
+ wsBtn.dataset.paused = paused ? 'true' : 'false';
+ wsBtn.textContent = paused ? '\u25B6' : '\u23F8';
+ wsBtn.title = paused ? 'Resume live updates' : 'Pause live updates';
+ }
+
+ updateWsIcon(initialPaused);
+
+ // Fire initial event so pages pick up persisted state
+ document.addEventListener('DOMContentLoaded', function() {
+ if (initialPaused) {
+ document.dispatchEvent(new CustomEvent('zen-ws-toggle', { detail: { paused: true } }));
+ }
+ });
+
+ wsBtn.addEventListener('click', function() {
+ var paused = wsBtn.dataset.paused !== 'true';
+ try { localStorage.setItem(WS_KEY, paused ? 'true' : 'false'); } catch (e) {}
+ updateWsIcon(paused);
+ document.dispatchEvent(new CustomEvent('zen-ws-toggle', { detail: { paused: paused } }));
+ });
+
+ document.body.appendChild(wsBtn);
+ }
+
+ if (document.readyState === 'loading')
+ document.addEventListener('DOMContentLoaded', createToggle);
+ else
+ createToggle();
+})();
diff --git a/src/zenserver/frontend/html/util/compactbinary.js b/src/zenserver/frontend/html/util/compactbinary.js
index 90e4249f6..415fa4be8 100644
--- a/src/zenserver/frontend/html/util/compactbinary.js
+++ b/src/zenserver/frontend/html/util/compactbinary.js
@@ -310,8 +310,8 @@ CbFieldView.prototype.as_value = function(int_type=BigInt)
case CbFieldType.IntegerPositive: return VarInt.read_uint(this._data_view, int_type)[0];
case CbFieldType.IntegerNegative: return VarInt.read_int(this._data_view, int_type)[0];
- case CbFieldType.Float32: return new DataView(this._data_view.subarray(0, 4)).getFloat32(0, false);
- case CbFieldType.Float64: return new DataView(this._data_view.subarray(0, 8)).getFloat64(0, false);
+ case CbFieldType.Float32: { const s = this._data_view; return new DataView(s.buffer, s.byteOffset, 4).getFloat32(0, false); }
+ case CbFieldType.Float64: { const s = this._data_view; return new DataView(s.buffer, s.byteOffset, 8).getFloat64(0, false); }
case CbFieldType.BoolFalse: return false;
case CbFieldType.BoolTrue: return true;
diff --git a/src/zenserver/frontend/html/util/friendly.js b/src/zenserver/frontend/html/util/friendly.js
index a15252faf..5d4586165 100644
--- a/src/zenserver/frontend/html/util/friendly.js
+++ b/src/zenserver/frontend/html/util/friendly.js
@@ -20,4 +20,25 @@ export class Friendly
static kib(x, p=0) { return Friendly.sep((BigInt(x) + 1023n) / (1n << 10n)|0n, p) + " KiB"; }
static mib(x, p=1) { return Friendly.sep( BigInt(x) / (1n << 20n), p) + " MiB"; }
static gib(x, p=2) { return Friendly.sep( BigInt(x) / (1n << 30n), p) + " GiB"; }
+
+ static duration(s)
+ {
+ const v = Number(s);
+ if (v >= 1) return Friendly.sep(v, 2) + " s";
+ if (v >= 0.001) return Friendly.sep(v * 1000, 2) + " ms";
+ if (v >= 0.000001) return Friendly.sep(v * 1000000, 1) + " µs";
+ return Friendly.sep(v * 1000000000, 0) + " ns";
+ }
+
+ static bytes(x)
+ {
+ const v = BigInt(Math.trunc(Number(x)));
+ if (v >= (1n << 60n)) return Friendly.sep(Number(v) / Number(1n << 60n), 2) + " EiB";
+ if (v >= (1n << 50n)) return Friendly.sep(Number(v) / Number(1n << 50n), 2) + " PiB";
+ if (v >= (1n << 40n)) return Friendly.sep(Number(v) / Number(1n << 40n), 2) + " TiB";
+ if (v >= (1n << 30n)) return Friendly.sep(Number(v) / Number(1n << 30n), 2) + " GiB";
+ if (v >= (1n << 20n)) return Friendly.sep(Number(v) / Number(1n << 20n), 1) + " MiB";
+ if (v >= (1n << 10n)) return Friendly.sep(Number(v) / Number(1n << 10n), 0) + " KiB";
+ return Friendly.sep(Number(v), 0) + " B";
+ }
}
diff --git a/src/zenserver/frontend/html/util/widgets.js b/src/zenserver/frontend/html/util/widgets.js
index 32a3f4d28..2964f92f2 100644
--- a/src/zenserver/frontend/html/util/widgets.js
+++ b/src/zenserver/frontend/html/util/widgets.js
@@ -54,6 +54,8 @@ export class Table extends Widget
static Flag_PackRight = 1 << 1;
static Flag_BiasLeft = 1 << 2;
static Flag_FitLeft = 1 << 3;
+ static Flag_Sortable = 1 << 4;
+ static Flag_AlignNumeric = 1 << 5;
constructor(parent, column_names, flags=Table.Flag_EvenSpacing, index_base=0)
{
@@ -81,11 +83,108 @@ export class Table extends Widget
root.style("gridTemplateColumns", column_style);
- this._add_row(column_names, false);
+ this._header_row = this._add_row(column_names, false);
this._index = index_base;
this._num_columns = column_names.length;
this._rows = [];
+ this._flags = flags;
+ this._sort_column = -1;
+ this._sort_ascending = true;
+
+ if (flags & Table.Flag_Sortable)
+ {
+ this._init_sortable();
+ }
+ }
+
+ _init_sortable()
+ {
+ const header_elem = this._element.firstElementChild;
+ if (!header_elem)
+ {
+ return;
+ }
+
+ const cells = header_elem.children;
+ for (let i = 0; i < cells.length; i++)
+ {
+ const cell = cells[i];
+ cell.style.cursor = "pointer";
+ cell.style.userSelect = "none";
+ cell.addEventListener("click", () => this._sort_by(i));
+ }
+ }
+
+ _sort_by(column_index)
+ {
+ if (this._sort_column === column_index)
+ {
+ this._sort_ascending = !this._sort_ascending;
+ }
+ else
+ {
+ this._sort_column = column_index;
+ this._sort_ascending = true;
+ }
+
+ // Update header indicators
+ const header_elem = this._element.firstElementChild;
+ for (const cell of header_elem.children)
+ {
+ const text = cell.textContent.replace(/ [▲▼]$/, "");
+ cell.textContent = text;
+ }
+ const active_cell = header_elem.children[column_index];
+ active_cell.textContent += this._sort_ascending ? " ▲" : " ▼";
+
+ // Sort rows by comparing cell text content
+ const dir = this._sort_ascending ? 1 : -1;
+ const unit_multipliers = { "B": 1, "KiB": 1024, "MiB": 1048576, "GiB": 1073741824, "TiB": 1099511627776, "PiB": 1125899906842624, "EiB": 1152921504606846976 };
+ const parse_sortable = (text) => {
+ // Try byte units first (e.g. "1,234 KiB", "1.5 GiB")
+ const byte_match = text.match(/^([\d,.]+)\s*(B|[KMGTPE]iB)/);
+ if (byte_match)
+ {
+ const num = parseFloat(byte_match[1].replace(/,/g, ""));
+ const mult = unit_multipliers[byte_match[2]] || 1;
+ return num * mult;
+ }
+ // Try percentage (e.g. "95.5%")
+ const pct_match = text.match(/^([\d,.]+)%/);
+ if (pct_match)
+ {
+ return parseFloat(pct_match[1].replace(/,/g, ""));
+ }
+ // Try plain number (possibly with commas/separators)
+ const num = parseFloat(text.replace(/,/g, ""));
+ if (!isNaN(num))
+ {
+ return num;
+ }
+ return null;
+ };
+ this._rows.sort((a, b) => {
+ const aElem = a.inner().children[column_index];
+ const bElem = b.inner().children[column_index];
+ const aText = aElem ? aElem.textContent : "";
+ const bText = bElem ? bElem.textContent : "";
+
+ const aNum = parse_sortable(aText);
+ const bNum = parse_sortable(bText);
+
+ if (aNum !== null && bNum !== null)
+ {
+ return (aNum - bNum) * dir;
+ }
+ return aText.localeCompare(bText) * dir;
+ });
+
+ // Re-order DOM elements
+ for (const row of this._rows)
+ {
+ this._element.appendChild(row.inner());
+ }
}
*[Symbol.iterator]()
@@ -121,6 +220,18 @@ export class Table extends Widget
ret.push(new TableCell(leaf, row));
}
+ if ((this._flags & Table.Flag_AlignNumeric) && indexed)
+ {
+ for (const c of ret)
+ {
+ const t = c.inner().textContent;
+ if (t && /^\d/.test(t))
+ {
+ c.style("textAlign", "right");
+ }
+ }
+ }
+
if (this._index >= 0)
ret.shift();
@@ -131,9 +242,34 @@ export class Table extends Widget
{
var row = this._add_row(args);
this._rows.push(row);
+
+ if ((this._flags & Table.Flag_AlignNumeric) && this._rows.length === 1)
+ {
+ this._align_header();
+ }
+
return row;
}
+ _align_header()
+ {
+ const first_row = this._rows[0];
+ if (!first_row)
+ {
+ return;
+ }
+ const header_elem = this._element.firstElementChild;
+ const header_cells = header_elem.children;
+ const data_cells = first_row.inner().children;
+ for (let i = 0; i < data_cells.length && i < header_cells.length; i++)
+ {
+ if (data_cells[i].style.textAlign === "right")
+ {
+ header_cells[i].style.textAlign = "right";
+ }
+ }
+ }
+
clear(index=0)
{
const elem = this._element;
diff --git a/src/zenserver/frontend/html/zen.css b/src/zenserver/frontend/html/zen.css
index cc53c0519..a968aecab 100644
--- a/src/zenserver/frontend/html/zen.css
+++ b/src/zenserver/frontend/html/zen.css
@@ -2,66 +2,202 @@
/* theme -------------------------------------------------------------------- */
+/* system preference (default) */
@media (prefers-color-scheme: light) {
:root {
- --theme_g0: #000;
- --theme_g4: #fff;
- --theme_g1: color-mix(in oklab, var(--theme_g0), var(--theme_g4) 45%);
- --theme_g2: color-mix(in oklab, var(--theme_g0), var(--theme_g4) 80%);
- --theme_g3: color-mix(in oklab, var(--theme_g0), var(--theme_g4) 96%);
-
- --theme_p0: #069;
- --theme_p4: hsl(210deg 40% 94%);
+ --theme_g0: #1f2328;
+ --theme_g1: #656d76;
+ --theme_g2: #d0d7de;
+ --theme_g3: #f6f8fa;
+ --theme_g4: #ffffff;
+
+ --theme_p0: #0969da;
+ --theme_p4: #ddf4ff;
--theme_p1: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 35%);
--theme_p2: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 60%);
--theme_p3: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 85%);
--theme_ln: var(--theme_p0);
- --theme_er: #fcc;
+ --theme_er: #ffebe9;
+
+ --theme_ok: #1a7f37;
+ --theme_warn: #9a6700;
+ --theme_fail: #cf222e;
+
+ --theme_bright: #1f2328;
+ --theme_faint: #6e7781;
+ --theme_border_subtle: #d8dee4;
}
}
@media (prefers-color-scheme: dark) {
:root {
- --theme_g0: #ddd;
- --theme_g4: #222;
- --theme_g1: color-mix(in oklab, var(--theme_g0), var(--theme_g4) 35%);
- --theme_g2: color-mix(in oklab, var(--theme_g0), var(--theme_g4) 65%);
- --theme_g3: color-mix(in oklab, var(--theme_g0), var(--theme_g4) 88%);
-
- --theme_p0: #479;
- --theme_p4: #333;
+ --theme_g0: #c9d1d9;
+ --theme_g1: #8b949e;
+ --theme_g2: #30363d;
+ --theme_g3: #161b22;
+ --theme_g4: #0d1117;
+
+ --theme_p0: #58a6ff;
+ --theme_p4: #1c2128;
--theme_p1: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 35%);
--theme_p2: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 60%);
--theme_p3: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 85%);
- --theme_ln: #feb;
- --theme_er: #622;
+ --theme_ln: #58a6ff;
+ --theme_er: #1c1c1c;
+
+ --theme_ok: #3fb950;
+ --theme_warn: #d29922;
+ --theme_fail: #f85149;
+
+ --theme_bright: #f0f6fc;
+ --theme_faint: #6e7681;
+ --theme_border_subtle: #21262d;
}
}
+/* manual overrides (higher specificity than media queries) */
+:root[data-theme="light"] {
+ --theme_g0: #1f2328;
+ --theme_g1: #656d76;
+ --theme_g2: #d0d7de;
+ --theme_g3: #f6f8fa;
+ --theme_g4: #ffffff;
+
+ --theme_p0: #0969da;
+ --theme_p4: #ddf4ff;
+ --theme_p1: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 35%);
+ --theme_p2: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 60%);
+ --theme_p3: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 85%);
+
+ --theme_ln: var(--theme_p0);
+ --theme_er: #ffebe9;
+
+ --theme_ok: #1a7f37;
+ --theme_warn: #9a6700;
+ --theme_fail: #cf222e;
+
+ --theme_bright: #1f2328;
+ --theme_faint: #6e7781;
+ --theme_border_subtle: #d8dee4;
+}
+
+:root[data-theme="dark"] {
+ --theme_g0: #c9d1d9;
+ --theme_g1: #8b949e;
+ --theme_g2: #30363d;
+ --theme_g3: #161b22;
+ --theme_g4: #0d1117;
+
+ --theme_p0: #58a6ff;
+ --theme_p4: #1c2128;
+ --theme_p1: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 35%);
+ --theme_p2: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 60%);
+ --theme_p3: color-mix(in oklab, var(--theme_p0), var(--theme_p4) 85%);
+
+ --theme_ln: #58a6ff;
+ --theme_er: #1c1c1c;
+
+ --theme_ok: #3fb950;
+ --theme_warn: #d29922;
+ --theme_fail: #f85149;
+
+ --theme_bright: #f0f6fc;
+ --theme_faint: #6e7681;
+ --theme_border_subtle: #21262d;
+}
+
+/* theme toggle ------------------------------------------------------------- */
+
+#zen_ws_toggle {
+ position: fixed;
+ top: 16px;
+ right: 60px;
+ z-index: 10;
+ width: 36px;
+ height: 36px;
+ border-radius: 6px;
+ border: 1px solid var(--theme_g2);
+ background: var(--theme_g3);
+ color: var(--theme_g1);
+ cursor: pointer;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ font-size: 18px;
+ line-height: 1;
+ transition: color 0.15s, background 0.15s, border-color 0.15s;
+ user-select: none;
+}
+
+#zen_ws_toggle:hover {
+ color: var(--theme_g0);
+ background: var(--theme_p4);
+ border-color: var(--theme_g1);
+}
+
+#zen_theme_toggle {
+ position: fixed;
+ top: 16px;
+ right: 16px;
+ z-index: 10;
+ width: 36px;
+ height: 36px;
+ border-radius: 6px;
+ border: 1px solid var(--theme_g2);
+ background: var(--theme_g3);
+ color: var(--theme_g1);
+ cursor: pointer;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ font-size: 18px;
+ line-height: 1;
+ transition: color 0.15s, background 0.15s, border-color 0.15s;
+ user-select: none;
+}
+
+#zen_theme_toggle:hover {
+ color: var(--theme_g0);
+ background: var(--theme_p4);
+ border-color: var(--theme_g1);
+}
+
/* page --------------------------------------------------------------------- */
-body, input {
- font-family: consolas, monospace;
- font-size: 11pt;
+body, input, button {
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
+ font-size: 14px;
}
body {
overflow-y: scroll;
margin: 0;
+ padding: 20px;
background-color: var(--theme_g4);
color: var(--theme_g0);
}
-pre {
- margin: 0;
+pre, code {
+ font-family: 'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace;
+ font-size: 13px;
+ margin: 0;
}
input {
color: var(--theme_g0);
background-color: var(--theme_g3);
border: 1px solid var(--theme_g2);
+ border-radius: 4px;
+ padding: 4px 8px;
+}
+
+button {
+ color: var(--theme_g0);
+ background: transparent;
+ border: none;
+ cursor: pointer;
}
* {
@@ -69,17 +205,44 @@ input {
}
#container {
- max-width: 130em;
- min-width: 80em;
+ max-width: 1400px;
margin: auto;
> div {
- margin: 0.0em 2.2em 0.0em 2.2em;
padding-top: 1.0em;
padding-bottom: 1.5em;
}
}
+/* service nav -------------------------------------------------------------- */
+
+#service_nav {
+ display: flex;
+ align-items: center;
+ gap: 4px;
+ margin-bottom: 16px;
+ padding: 4px;
+ background-color: var(--theme_g3);
+ border: 1px solid var(--theme_g2);
+ border-radius: 6px;
+
+ a {
+ padding: 6px 14px;
+ border-radius: 4px;
+ font-size: 13px;
+ font-weight: 500;
+ color: var(--theme_g1);
+ text-decoration: none;
+ transition: color 0.15s, background 0.15s;
+ }
+
+ a:hover {
+ background-color: var(--theme_p4);
+ color: var(--theme_g0);
+ text-decoration: none;
+ }
+}
+
/* links -------------------------------------------------------------------- */
a {
@@ -103,28 +266,37 @@ a {
}
h1 {
- font-size: 1.5em;
+ font-size: 20px;
+ font-weight: 600;
width: 100%;
+ color: var(--theme_bright);
border-bottom: 1px solid var(--theme_g2);
+ padding-bottom: 0.4em;
+ margin-bottom: 16px;
}
h2 {
- font-size: 1.25em;
- margin-bottom: 0.5em;
+ font-size: 16px;
+ font-weight: 600;
+ margin-bottom: 12px;
}
h3 {
- font-size: 1.1em;
+ font-size: 14px;
+ font-weight: 600;
margin: 0em;
- padding: 0.4em;
- background-color: var(--theme_p4);
- border-left: 5px solid var(--theme_p2);
- font-weight: normal;
+ padding: 8px 12px;
+ background-color: var(--theme_g3);
+ border: 1px solid var(--theme_g2);
+ border-radius: 6px 6px 0 0;
+ color: var(--theme_g1);
+ text-transform: uppercase;
+ letter-spacing: 0.5px;
}
- margin-bottom: 3em;
+ margin-bottom: 2em;
> *:not(h1) {
- margin-left: 2em;
+ margin-left: 0;
}
}
@@ -134,23 +306,36 @@ a {
.zen_table {
display: grid;
border: 1px solid var(--theme_g2);
- border-left-style: none;
+ border-radius: 6px;
+ overflow: hidden;
margin-bottom: 1.2em;
+ font-size: 13px;
> div {
display: contents;
}
- > div:nth-of-type(odd) {
+ > div:nth-of-type(odd) > div {
+ background-color: var(--theme_g4);
+ }
+
+ > div:nth-of-type(even) > div {
background-color: var(--theme_g3);
}
> div:first-of-type {
- font-weight: bold;
- background-color: var(--theme_p3);
+ font-weight: 600;
+ > div {
+ background-color: var(--theme_g3);
+ color: var(--theme_g1);
+ text-transform: uppercase;
+ letter-spacing: 0.5px;
+ font-size: 11px;
+ border-bottom: 1px solid var(--theme_g2);
+ }
}
- > div:hover {
+ > div:not(:first-of-type):hover > div {
background-color: var(--theme_p4);
}
@@ -160,16 +345,37 @@ a {
}
> div > div {
- padding: 0.3em;
- padding-left: 0.75em;
- padding-right: 0.75em;
+ padding: 8px 12px;
align-content: center;
- border-left: 1px solid var(--theme_g2);
+ border-left: 1px solid var(--theme_border_subtle);
overflow: auto;
overflow-wrap: break-word;
- background-color: inherit;
white-space: pre-wrap;
}
+
+ > div > div:first-child {
+ border-left: none;
+ }
+}
+
+/* expandable cell ---------------------------------------------------------- */
+
+.zen_expand_icon {
+ cursor: pointer;
+ margin-right: 0.5em;
+ color: var(--theme_g1);
+ font-weight: bold;
+ user-select: none;
+}
+
+.zen_expand_icon:hover {
+ color: var(--theme_ln);
+}
+
+.zen_data_text {
+ user-select: text;
+ font-family: 'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace;
+ font-size: 13px;
}
/* toolbar ------------------------------------------------------------------ */
@@ -178,6 +384,7 @@ a {
display: flex;
margin-top: 0.5em;
margin-bottom: 0.6em;
+ font-size: 13px;
> div {
display: flex;
@@ -225,15 +432,16 @@ a {
z-index: -1;
top: 0;
left: 0;
- width: 100%;
+ width: 100%;
height: 100%;
background: var(--theme_g0);
opacity: 0.4;
}
> div {
- border-radius: 0.5em;
- background-color: var(--theme_g4);
+ border-radius: 6px;
+ background-color: var(--theme_g3);
+ border: 1px solid var(--theme_g2);
opacity: 1.0;
width: 35em;
padding: 0em 2em 2em 2em;
@@ -244,10 +452,11 @@ a {
}
.zen_modal_title {
- font-size: 1.2em;
+ font-size: 16px;
+ font-weight: 600;
border-bottom: 1px solid var(--theme_g2);
padding: 1.2em 0em 0.5em 0em;
- color: var(--theme_g1);
+ color: var(--theme_bright);
}
.zen_modal_buttons {
@@ -257,16 +466,19 @@ a {
> div {
margin: 0em 1em 0em 1em;
- padding: 1em;
+ padding: 10px 16px;
align-content: center;
- border-radius: 0.3em;
- background-color: var(--theme_p3);
+ border-radius: 6px;
+ background-color: var(--theme_p4);
+ border: 1px solid var(--theme_g2);
width: 6em;
cursor: pointer;
+ font-weight: 500;
+ transition: background 0.15s;
}
> div:hover {
- background-color: var(--theme_p4);
+ background-color: var(--theme_p3);
}
}
@@ -284,15 +496,18 @@ a {
top: 0;
left: 0;
width: 100%;
- height: 0.5em;
+ height: 4px;
+ border-radius: 2px;
+ overflow: hidden;
> div:first-of-type {
/* label */
padding: 0.3em;
- padding-top: 0.8em;
- background-color: var(--theme_p4);
+ padding-top: 8px;
+ background-color: var(--theme_g3);
width: max-content;
- font-size: 0.8em;
+ font-size: 12px;
+ color: var(--theme_g1);
}
> div:last-of-type {
@@ -302,7 +517,8 @@ a {
left: 0;
width: 0%;
height: 100%;
- background-color: var(--theme_p1);
+ background-color: var(--theme_p0);
+ transition: width 0.3s ease;
}
> div:nth-of-type(2) {
@@ -312,7 +528,7 @@ a {
left: 0;
width: 100%;
height: 100%;
- background-color: var(--theme_p3);
+ background-color: var(--theme_g3);
}
}
@@ -321,53 +537,25 @@ a {
#crumbs {
display: flex;
position: relative;
- top: -1em;
+ top: -0.5em;
+ font-size: 13px;
+ color: var(--theme_g1);
> div {
padding-right: 0.5em;
}
> div:nth-child(odd)::after {
- content: ":";
- font-weight: bolder;
- color: var(--theme_p2);
+ content: "/";
+ color: var(--theme_g2);
+ padding-left: 0.5em;
}
}
-/* branding ----------------------------------------------------------------- */
-
-#branding {
- font-size: 10pt;
- font-weight: bolder;
- margin-bottom: 2.6em;
- position: relative;
+/* banner ------------------------------------------------------------------- */
- #logo {
- width: min-content;
- margin: auto;
- user-select: none;
- position: relative;
-
- #go_home {
- width: 100%;
- height: 100%;
- position: absolute;
- top: 0;
- left: 0;
- }
- }
-
- #logo:hover {
- filter: drop-shadow(0 0.15em 0.1em var(--theme_p2));
- }
-
- #ue_logo {
- position: absolute;
- top: 1em;
- right: 0;
- width: 5em;
- margin: auto;
- }
+zen-banner {
+ margin-bottom: 24px;
}
/* error -------------------------------------------------------------------- */
@@ -378,26 +566,23 @@ a {
z-index: 1;
color: var(--theme_g0);
background-color: var(--theme_er);
- padding: 1.0em 2em 2em 2em;
+ padding: 12px 20px 16px 20px;
width: 100%;
- border-top: 1px solid var(--theme_g0);
+ border-top: 1px solid var(--theme_g2);
display: flex;
+ gap: 16px;
+ align-items: center;
+ font-size: 13px;
> div:nth-child(1) {
- font-size: 2.5em;
- font-weight: bolder;
- font-family: serif;
- transform: rotate(-13deg);
- color: var(--theme_p0);
- }
-
- > div:nth-child(2) {
- margin-left: 2em;
+ font-size: 24px;
+ font-weight: bold;
+ color: var(--theme_fail);
}
> div:nth-child(2) > pre:nth-child(2) {
- margin-top: 0.5em;
- font-size: 0.8em;
+ margin-top: 4px;
+ font-size: 12px;
color: var(--theme_g1);
}
}
@@ -409,18 +594,144 @@ a {
min-width: 15%;
}
+/* sections ----------------------------------------------------------------- */
+
+.zen_sector {
+ position: relative;
+}
+
+.dropall {
+ position: absolute;
+ top: 16px;
+ right: 0;
+ font-size: 12px;
+ margin: 0;
+}
+
+/* stats tiles -------------------------------------------------------------- */
+
+.stats-tiles {
+ grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
+}
+
+.stats-tile {
+ cursor: pointer;
+ transition: border-color 0.15s, background 0.15s;
+}
+
+.stats-tile:hover {
+ border-color: var(--theme_p0);
+ background: var(--theme_p4);
+}
+
+.stats-tile-detailed {
+ position: relative;
+}
+
+.stats-tile-detailed::after {
+ content: "details \203A";
+ position: absolute;
+ bottom: 12px;
+ right: 20px;
+ font-size: 11px;
+ color: var(--theme_g1);
+ opacity: 0.6;
+ transition: opacity 0.15s;
+}
+
+.stats-tile-detailed:hover::after {
+ opacity: 1;
+ color: var(--theme_p0);
+}
+
+.stats-tile-selected {
+ border-color: var(--theme_p0);
+ background: var(--theme_p4);
+ box-shadow: 0 0 0 1px var(--theme_p0);
+}
+
+.stats-tile-selected::after {
+ content: "details \2039";
+ opacity: 1;
+ color: var(--theme_p0);
+}
+
+.tile-metrics {
+ display: flex;
+ flex-direction: column;
+ gap: 12px;
+}
+
+.tile-columns {
+ display: flex;
+ gap: 24px;
+}
+
+.tile-columns > .tile-metrics {
+ flex: 1;
+ min-width: 0;
+}
+
+.tile-metric .metric-value {
+ font-size: 16px;
+}
+
+.tile-metric-hero .metric-value {
+ font-size: 28px;
+}
+
/* start -------------------------------------------------------------------- */
#start {
- .dropall {
- text-align: right;
- font-size: 0.85em;
- margin: -0.5em 0 0.5em 0;
- }
#version {
- color: var(--theme_g1);
+ color: var(--theme_faint);
text-align: center;
- font-size: 0.85em;
+ font-size: 12px;
+ margin-top: 24px;
+ }
+}
+
+/* info --------------------------------------------------------------------- */
+
+#info {
+ .info-tiles {
+ grid-template-columns: repeat(auto-fit, minmax(320px, 1fr));
+ }
+
+ .info-tile {
+ overflow: hidden;
+ }
+
+ .info-props {
+ display: flex;
+ flex-direction: column;
+ gap: 1px;
+ font-size: 13px;
+ }
+
+ .info-prop {
+ display: flex;
+ gap: 12px;
+ padding: 4px 0;
+ border-bottom: 1px solid var(--theme_border_subtle);
+ }
+
+ .info-prop:last-child {
+ border-bottom: none;
+ }
+
+ .info-prop-label {
+ color: var(--theme_g1);
+ min-width: 140px;
+ flex-shrink: 0;
+ text-transform: capitalize;
+ }
+
+ .info-prop-value {
+ color: var(--theme_bright);
+ word-break: break-all;
+ margin-left: auto;
+ text-align: right;
}
}
@@ -437,6 +748,8 @@ a {
/* tree --------------------------------------------------------------------- */
#tree {
+ font-size: 13px;
+
#tree_root > ul {
margin-left: 0em;
}
@@ -448,29 +761,33 @@ a {
li > div {
display: flex;
border-bottom: 1px solid transparent;
- padding-left: 0.3em;
- padding-right: 0.3em;
+ padding: 4px 6px;
+ border-radius: 4px;
}
li > div > div[active] {
text-transform: uppercase;
+ color: var(--theme_p0);
+ font-weight: 600;
}
li > div > div:nth-last-child(3) {
margin-left: auto;
}
li > div > div:nth-last-child(-n + 3) {
- font-size: 0.8em;
+ font-size: 12px;
width: 10em;
text-align: right;
+ color: var(--theme_g1);
+ font-family: 'SF Mono', 'Cascadia Mono', Consolas, monospace;
}
li > div > div:nth-last-child(1) {
width: 6em;
}
li > div:hover {
background-color: var(--theme_p4);
- border-bottom: 1px solid var(--theme_g2);
+ border-bottom: 1px solid var(--theme_border_subtle);
}
li a {
- font-weight: bolder;
+ font-weight: 600;
}
li::marker {
content: "+";
@@ -503,3 +820,262 @@ html:has(#map) {
}
}
}
+
+/* ========================================================================== */
+/* Shared classes for compute / dashboard pages */
+/* ========================================================================== */
+
+/* cards -------------------------------------------------------------------- */
+
+.card {
+ background: var(--theme_g3);
+ border: 1px solid var(--theme_g2);
+ border-radius: 6px;
+ padding: 20px;
+}
+
+.card-title {
+ font-size: 14px;
+ font-weight: 600;
+ color: var(--theme_g1);
+ margin-bottom: 12px;
+ text-transform: uppercase;
+ letter-spacing: 0.5px;
+}
+
+/* grid --------------------------------------------------------------------- */
+
+.grid {
+ display: grid;
+ grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
+ gap: 20px;
+ margin-bottom: 24px;
+}
+
+/* metrics ------------------------------------------------------------------ */
+
+.metric-value {
+ font-size: 36px;
+ font-weight: 600;
+ color: var(--theme_bright);
+ line-height: 1;
+}
+
+.metric-label {
+ font-size: 12px;
+ color: var(--theme_g1);
+ margin-top: 4px;
+}
+
+/* section titles ----------------------------------------------------------- */
+
+.section-title {
+ font-size: 20px;
+ font-weight: 600;
+ margin-bottom: 16px;
+ color: var(--theme_bright);
+}
+
+/* html tables (compute pages) ---------------------------------------------- */
+
+table {
+ width: 100%;
+ border-collapse: collapse;
+ font-size: 13px;
+}
+
+th {
+ text-align: left;
+ color: var(--theme_g1);
+ padding: 8px 12px;
+ border-bottom: 1px solid var(--theme_g2);
+ font-weight: 600;
+ text-transform: uppercase;
+ letter-spacing: 0.5px;
+ font-size: 11px;
+}
+
+td {
+ padding: 8px 12px;
+ border-bottom: 1px solid var(--theme_border_subtle);
+ color: var(--theme_g0);
+}
+
+tr:last-child td {
+ border-bottom: none;
+}
+
+.total-row td {
+ border-top: 2px solid var(--theme_g2);
+ font-weight: 600;
+ color: var(--theme_bright);
+}
+
+/* status badges ------------------------------------------------------------ */
+
+.status-badge {
+ display: inline-block;
+ padding: 2px 8px;
+ border-radius: 4px;
+ font-size: 11px;
+ font-weight: 600;
+}
+
+.status-badge.active,
+.status-badge.success {
+ background: color-mix(in srgb, var(--theme_ok) 15%, transparent);
+ color: var(--theme_ok);
+}
+
+.status-badge.inactive {
+ background: color-mix(in srgb, var(--theme_g1) 15%, transparent);
+ color: var(--theme_g1);
+}
+
+.status-badge.failure {
+ background: color-mix(in srgb, var(--theme_fail) 15%, transparent);
+ color: var(--theme_fail);
+}
+
+/* health dots -------------------------------------------------------------- */
+
+.health-dot {
+ display: inline-block;
+ width: 10px;
+ height: 10px;
+ border-radius: 50%;
+ background: var(--theme_g1);
+}
+
+.health-green {
+ background: var(--theme_ok);
+}
+
+.health-yellow {
+ background: var(--theme_warn);
+}
+
+.health-red {
+ background: var(--theme_fail);
+}
+
+/* inline progress bar ------------------------------------------------------ */
+
+.progress-bar {
+ width: 100%;
+ height: 8px;
+ background: var(--theme_border_subtle);
+ border-radius: 4px;
+ overflow: hidden;
+ margin-top: 8px;
+}
+
+.progress-fill {
+ height: 100%;
+ background: var(--theme_p0);
+ transition: width 0.3s ease;
+}
+
+/* stats row (label + value pair) ------------------------------------------- */
+
+.stats-row {
+ display: flex;
+ justify-content: space-between;
+ margin-bottom: 12px;
+ padding: 8px 0;
+ border-bottom: 1px solid var(--theme_border_subtle);
+}
+
+.stats-row:last-child {
+ border-bottom: none;
+ margin-bottom: 0;
+}
+
+.stats-label {
+ color: var(--theme_g1);
+ font-size: 13px;
+}
+
+.stats-value {
+ color: var(--theme_bright);
+ font-weight: 600;
+ font-size: 13px;
+}
+
+/* detail tag (inline badge) ------------------------------------------------ */
+
+.detail-tag {
+ display: inline-block;
+ padding: 2px 8px;
+ border-radius: 4px;
+ background: var(--theme_border_subtle);
+ color: var(--theme_g0);
+ font-size: 11px;
+ margin: 2px 4px 2px 0;
+}
+
+/* timestamp ---------------------------------------------------------------- */
+
+.timestamp {
+ font-size: 12px;
+ color: var(--theme_faint);
+}
+
+/* inline error ------------------------------------------------------------- */
+
+.error {
+ color: var(--theme_fail);
+ padding: 12px;
+ background: var(--theme_er);
+ border-radius: 6px;
+ margin: 20px 0;
+ font-size: 13px;
+}
+
+/* empty state -------------------------------------------------------------- */
+
+.empty-state {
+ color: var(--theme_faint);
+ font-size: 13px;
+ padding: 20px 0;
+ text-align: center;
+}
+
+/* header layout ------------------------------------------------------------ */
+
+.header {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+ margin-bottom: 24px;
+}
+
+/* history tabs ------------------------------------------------------------- */
+
+.history-tabs {
+ display: flex;
+ gap: 4px;
+ background: var(--theme_g4);
+ border-radius: 6px;
+ padding: 2px;
+}
+
+.history-tab {
+ background: transparent;
+ color: var(--theme_g1);
+ font-size: 12px;
+ font-weight: 600;
+ padding: 4px 12px;
+ border-radius: 4px;
+ text-transform: uppercase;
+ letter-spacing: 0.5px;
+}
+
+.history-tab:hover {
+ color: var(--theme_g0);
+}
+
+.history-tab.active {
+ background: var(--theme_g2);
+ color: var(--theme_bright);
+}
diff --git a/src/zenserver/frontend/zipfs.cpp b/src/zenserver/frontend/zipfs.cpp
index f9c2bc8ff..42df0520f 100644
--- a/src/zenserver/frontend/zipfs.cpp
+++ b/src/zenserver/frontend/zipfs.cpp
@@ -149,13 +149,25 @@ ZipFs::ZipFs(IoBuffer&& Buffer)
IoBuffer
ZipFs::GetFile(const std::string_view& FileName) const
{
- FileMap::iterator Iter = m_Files.find(FileName);
- if (Iter == m_Files.end())
{
- return {};
+ RwLock::SharedLockScope _(m_FilesLock);
+
+ FileMap::const_iterator Iter = m_Files.find(FileName);
+ if (Iter == m_Files.end())
+ {
+ return {};
+ }
+
+ const FileItem& Item = Iter->second;
+ if (Item.GetSize() > 0)
+ {
+ return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize());
+ }
}
- FileItem& Item = Iter->second;
+ RwLock::ExclusiveLockScope _(m_FilesLock);
+
+ FileItem& Item = m_Files.find(FileName)->second;
if (Item.GetSize() > 0)
{
return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize());
diff --git a/src/zenserver/frontend/zipfs.h b/src/zenserver/frontend/zipfs.h
index 1fa7da451..645121693 100644
--- a/src/zenserver/frontend/zipfs.h
+++ b/src/zenserver/frontend/zipfs.h
@@ -3,23 +3,23 @@
#pragma once
#include <zencore/iobuffer.h>
+#include <zencore/thread.h>
#include <unordered_map>
namespace zen {
-//////////////////////////////////////////////////////////////////////////
class ZipFs
{
public:
- ZipFs() = default;
- ZipFs(IoBuffer&& Buffer);
+ explicit ZipFs(IoBuffer&& Buffer);
+
IoBuffer GetFile(const std::string_view& FileName) const;
- inline operator bool() const { return !m_Files.empty(); }
private:
using FileItem = MemoryView;
using FileMap = std::unordered_map<std::string_view, FileItem>;
+ mutable RwLock m_FilesLock;
FileMap mutable m_Files;
IoBuffer m_Buffer;
};
diff --git a/src/zenserver/hub/hubservice.cpp b/src/zenserver/hub/hubservice.cpp
index 4d9da3a57..7b999ae20 100644
--- a/src/zenserver/hub/hubservice.cpp
+++ b/src/zenserver/hub/hubservice.cpp
@@ -4,10 +4,12 @@
#include "hydration.h"
+#include <zencore/assertfmt.h>
#include <zencore/compactbinarybuilder.h>
#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
+#include <zencore/process.h>
#include <zencore/scopeguard.h>
#include <zencore/system.h>
#include <zenutil/zenserverprocess.h>
@@ -150,7 +152,12 @@ struct StorageServerInstance
inline uint16_t GetBasePort() const { return m_ServerInstance.GetBasePort(); }
+#if ZEN_PLATFORM_WINDOWS
+ void SetJobObject(JobObject* InJobObject) { m_JobObject = InJobObject; }
+#endif
+
private:
+ void WakeLocked();
RwLock m_Lock;
std::string m_ModuleId;
std::atomic<bool> m_IsProvisioned{false};
@@ -160,6 +167,9 @@ private:
std::filesystem::path m_TempDir;
std::filesystem::path m_HydrationPath;
ResourceMetrics m_ResourceMetrics;
+#if ZEN_PLATFORM_WINDOWS
+ JobObject* m_JobObject = nullptr;
+#endif
void SpawnServerProcess();
@@ -186,10 +196,13 @@ StorageServerInstance::~StorageServerInstance()
void
StorageServerInstance::SpawnServerProcess()
{
- ZEN_ASSERT(!m_ServerInstance.IsRunning(), "Storage server instance for module '{}' is already running", m_ModuleId);
+ ZEN_ASSERT_FORMAT(!m_ServerInstance.IsRunning(), "Storage server instance for module '{}' is already running", m_ModuleId);
m_ServerInstance.SetServerExecutablePath(GetRunningExecutablePath());
m_ServerInstance.SetDataDir(m_BaseDir);
+#if ZEN_PLATFORM_WINDOWS
+ m_ServerInstance.SetJobObject(m_JobObject);
+#endif
const uint16_t BasePort = m_ServerInstance.SpawnServerAndWaitUntilReady();
ZEN_DEBUG("Storage server instance for module '{}' started, listening on port {}", m_ModuleId, BasePort);
@@ -211,7 +224,7 @@ StorageServerInstance::Provision()
if (m_IsHibernated)
{
- Wake();
+ WakeLocked();
}
else
{
@@ -294,9 +307,14 @@ StorageServerInstance::Hibernate()
void
StorageServerInstance::Wake()
{
- // Start server in-place using existing data
-
RwLock::ExclusiveLockScope _(m_Lock);
+ WakeLocked();
+}
+
+void
+StorageServerInstance::WakeLocked()
+{
+ // Start server in-place using existing data
if (!m_IsHibernated)
{
@@ -305,7 +323,7 @@ StorageServerInstance::Wake()
return;
}
- ZEN_ASSERT(!m_ServerInstance.IsRunning(), "Storage server instance for module '{}' is already running", m_ModuleId);
+ ZEN_ASSERT_FORMAT(!m_ServerInstance.IsRunning(), "Storage server instance for module '{}' is already running", m_ModuleId);
try
{
@@ -374,6 +392,21 @@ struct HttpHubService::Impl
// flexibility, and to allow running multiple hubs on the same host if
// necessary.
m_RunEnvironment.SetNextPortNumber(21000);
+
+#if ZEN_PLATFORM_WINDOWS
+ if (m_UseJobObject)
+ {
+ m_JobObject.Initialize();
+ if (m_JobObject.IsValid())
+ {
+ ZEN_INFO("Job object initialized for hub service child process management");
+ }
+ else
+ {
+ ZEN_WARN("Failed to initialize job object; child processes will not be auto-terminated on hub crash");
+ }
+ }
+#endif
}
void Cleanup()
@@ -416,6 +449,12 @@ struct HttpHubService::Impl
IsNewInstance = true;
auto NewInstance =
std::make_unique<StorageServerInstance>(m_RunEnvironment, ModuleId, m_FileHydrationPath, m_HydrationTempPath);
+#if ZEN_PLATFORM_WINDOWS
+ if (m_JobObject.IsValid())
+ {
+ NewInstance->SetJobObject(&m_JobObject);
+ }
+#endif
Instance = NewInstance.get();
m_Instances.emplace(std::string(ModuleId), std::move(NewInstance));
@@ -573,10 +612,15 @@ struct HttpHubService::Impl
inline int GetInstanceLimit() { return m_InstanceLimit; }
inline int GetMaxInstanceCount() { return m_MaxInstanceCount; }
+ bool m_UseJobObject = true;
+
private:
- ZenServerEnvironment m_RunEnvironment;
- std::filesystem::path m_FileHydrationPath;
- std::filesystem::path m_HydrationTempPath;
+ ZenServerEnvironment m_RunEnvironment;
+ std::filesystem::path m_FileHydrationPath;
+ std::filesystem::path m_HydrationTempPath;
+#if ZEN_PLATFORM_WINDOWS
+ JobObject m_JobObject;
+#endif
RwLock m_Lock;
std::unordered_map<std::string, std::unique_ptr<StorageServerInstance>> m_Instances;
std::unordered_set<std::string> m_DeprovisioningModules;
@@ -802,7 +846,7 @@ HttpHubService::HttpHubService(std::filesystem::path HubBaseDir, std::filesystem
Obj << "currentInstanceCount" << m_Impl->GetInstanceCount();
Obj << "maxInstanceCount" << m_Impl->GetMaxInstanceCount();
Obj << "instanceLimit" << m_Impl->GetInstanceLimit();
- Req.ServerRequest().WriteResponse(HttpResponseCode::OK);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
},
HttpVerb::kGet);
}
@@ -811,6 +855,12 @@ HttpHubService::~HttpHubService()
{
}
+void
+HttpHubService::SetUseJobObject(bool Enable)
+{
+ m_Impl->m_UseJobObject = Enable;
+}
+
const char*
HttpHubService::BaseUri() const
{
diff --git a/src/zenserver/hub/hubservice.h b/src/zenserver/hub/hubservice.h
index 1a5a8c57c..ef24bba69 100644
--- a/src/zenserver/hub/hubservice.h
+++ b/src/zenserver/hub/hubservice.h
@@ -28,6 +28,13 @@ public:
void SetNotificationEndpoint(std::string_view UpstreamNotificationEndpoint, std::string_view InstanceId);
+ /** Enable or disable the use of a Windows Job Object for child process management.
+ * When enabled, all spawned child processes are assigned to a job object with
+ * JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, ensuring children are terminated if the hub
+ * crashes or is force-killed. Must be called before Initialize(). No-op on non-Windows.
+ */
+ void SetUseJobObject(bool Enable);
+
private:
HttpRequestRouter m_Router;
diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp
index 7a4ba951d..c6d2dc8d4 100644
--- a/src/zenserver/hub/zenhubserver.cpp
+++ b/src/zenserver/hub/zenhubserver.cpp
@@ -105,7 +105,7 @@ ZenHubServer::Initialize(const ZenHubServerConfig& ServerConfig, ZenServerState:
void
ZenHubServer::Cleanup()
{
- ZEN_TRACE_CPU("ZenStorageServer::Cleanup");
+ ZEN_TRACE_CPU("ZenHubServer::Cleanup");
ZEN_INFO(ZEN_APP_NAME " cleaning up");
try
{
@@ -115,6 +115,8 @@ ZenHubServer::Cleanup()
m_IoRunner.join();
}
+ ShutdownServices();
+
if (m_Http)
{
m_Http->Close();
@@ -143,6 +145,8 @@ ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig)
ZEN_INFO("instantiating hub service");
m_HubService = std::make_unique<HttpHubService>(ServerConfig.DataDir / "hub", ServerConfig.DataDir / "servers");
m_HubService->SetNotificationEndpoint(ServerConfig.UpstreamNotificationEndpoint, ServerConfig.InstanceId);
+
+ m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatusService);
}
void
@@ -159,6 +163,11 @@ ZenHubServer::RegisterServices(const ZenHubServerConfig& ServerConfig)
{
m_Http->RegisterService(*m_ApiService);
}
+
+ if (m_FrontendService)
+ {
+ m_Http->RegisterService(*m_FrontendService);
+ }
}
void
diff --git a/src/zenserver/hub/zenhubserver.h b/src/zenserver/hub/zenhubserver.h
index ac14362f0..4c56fdce5 100644
--- a/src/zenserver/hub/zenhubserver.h
+++ b/src/zenserver/hub/zenhubserver.h
@@ -2,6 +2,7 @@
#pragma once
+#include "frontend/frontend.h"
#include "zenserver.h"
namespace cxxopts {
@@ -81,8 +82,9 @@ private:
std::filesystem::path m_ContentRoot;
bool m_DebugOptionForcedCrash = false;
- std::unique_ptr<HttpHubService> m_HubService;
- std::unique_ptr<HttpApiService> m_ApiService;
+ std::unique_ptr<HttpHubService> m_HubService;
+ std::unique_ptr<HttpApiService> m_ApiService;
+ std::unique_ptr<HttpFrontendService> m_FrontendService;
void InitializeState(const ZenHubServerConfig& ServerConfig);
void InitializeServices(const ZenHubServerConfig& ServerConfig);
diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp
index 3a58d1f4a..09ecc48e5 100644
--- a/src/zenserver/main.cpp
+++ b/src/zenserver/main.cpp
@@ -19,10 +19,13 @@
#include <zencore/thread.h>
#include <zencore/trace.h>
#include <zentelemetry/otlptrace.h>
-#include <zenutil/commandlineoptions.h>
+#include <zenutil/config/commandlineoptions.h>
#include <zenutil/service.h>
#include "diag/logging.h"
+
+#include "compute/computeserver.h"
+
#include "storage/storageconfig.h"
#include "storage/zenstorageserver.h"
@@ -38,7 +41,6 @@
// in some shared code into the executable
#if ZEN_WITH_TESTS
-# define ZEN_TEST_WITH_RUNNER 1
# include <zencore/testing.h>
#endif
@@ -61,11 +63,19 @@ namespace zen {
#if ZEN_PLATFORM_WINDOWS
-template<class T>
+/** Windows Service wrapper for Zen servers
+ *
+ * This class wraps a Zen server main entry point (the Main template parameter)
+ * into a Windows Service by implementing the WindowsService interface.
+ *
+ * The Main type needs to implement the virtual functions from the ZenServerMain
+ * base class, which provides the actual server logic.
+ */
+template<class Main>
class ZenWindowsService : public WindowsService
{
public:
- ZenWindowsService(typename T::Config& ServerOptions) : m_EntryPoint(ServerOptions) {}
+ ZenWindowsService(typename Main::Config& ServerOptions) : m_EntryPoint(ServerOptions) {}
ZenWindowsService(const ZenWindowsService&) = delete;
ZenWindowsService& operator=(const ZenWindowsService&) = delete;
@@ -73,7 +83,7 @@ public:
virtual int Run() override { return m_EntryPoint.Run(); }
private:
- T m_EntryPoint;
+ Main m_EntryPoint;
};
#endif // ZEN_PLATFORM_WINDOWS
@@ -84,6 +94,23 @@ private:
namespace zen {
+/** Application main entry point template
+ *
+ * This function handles common application startup tasks while allowing
+ * different server types to be plugged in via the Main template parameter.
+ *
+ * On Windows, this function also handles platform-specific service
+ * installation and uninstallation.
+ *
+ * The Main type needs to implement the virtual functions from the ZenServerMain
+ * base class, which provides the actual server logic.
+ *
+ * The Main type is also expected to provide the following members:
+ *
+ * typedef Config -- Server configuration type, derived from ZenServerConfig
+ * typedef Configurator -- Server configuration handler type, implements ZenServerConfiguratorBase
+ *
+ */
template<class Main>
int
AppMain(int argc, char* argv[])
@@ -219,7 +246,7 @@ test_main(int argc, char** argv)
# endif // ZEN_PLATFORM_WINDOWS
zen::logging::InitializeLogging();
- zen::logging::SetLogLevel(zen::logging::level::Debug);
+ zen::logging::SetLogLevel(zen::logging::Debug);
zen::MaximizeOpenFileCount();
@@ -239,16 +266,31 @@ main(int argc, char* argv[])
using namespace zen;
using namespace std::literals;
+ // note: doctest has locally (in thirdparty) been fixed to not cause shutdown
+ // crashes due to TLS destructors
+ //
+ // mimalloc on the other hand might still be causing issues, in which case
+ // we should work out either how to eliminate the mimalloc dependency or how
+ // to configure it in a way that doesn't cause shutdown issues
+
+#if 0
auto _ = zen::MakeGuard([] {
// Allow some time for worker threads to unravel, in an effort
- // to prevent shutdown races in TLS object destruction
+ // to prevent shutdown races in TLS object destruction, mainly due to
+ // threads which we don't directly control (Windows thread pool) and
+ // therefore can't join.
+ //
+ // This isn't a great solution, but for now it seems to help reduce
+ // shutdown crashes observed in some situations.
WaitForThreads(1000);
});
+#endif
enum
{
kHub,
kStore,
+ kCompute,
kTest
} ServerMode = kStore;
@@ -258,10 +300,14 @@ main(int argc, char* argv[])
{
ServerMode = kHub;
}
- else if (argv[1] == "store"sv)
+ else if ((argv[1] == "store"sv) || (argv[1] == "storage"sv))
{
ServerMode = kStore;
}
+ else if (argv[1] == "compute"sv)
+ {
+ ServerMode = kCompute;
+ }
else if (argv[1] == "test"sv)
{
ServerMode = kTest;
@@ -280,6 +326,13 @@ main(int argc, char* argv[])
break;
case kHub:
return AppMain<ZenHubServerMain>(argc, argv);
+ case kCompute:
+#if ZEN_WITH_COMPUTE_SERVICES
+ return AppMain<ZenComputeServerMain>(argc, argv);
+#else
+ fprintf(stderr, "compute services are not compiled in!\n");
+ exit(5);
+#endif
default:
case kStore:
return AppMain<ZenStorageServerMain>(argc, argv);
diff --git a/src/zenserver/sessions/httpsessions.cpp b/src/zenserver/sessions/httpsessions.cpp
new file mode 100644
index 000000000..05be3c814
--- /dev/null
+++ b/src/zenserver/sessions/httpsessions.cpp
@@ -0,0 +1,264 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpsessions.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/trace.h>
+#include "sessions.h"
+
+namespace zen {
+using namespace std::literals;
+
// Construct the HTTP facade for the sessions service.
//
// The referenced services are retained for the lifetime of this object and
// must outlive it. Route and handler registration happens in Initialize().
HttpSessionsService::HttpSessionsService(HttpStatusService& StatusService, HttpStatsService& StatsService, SessionsService& Sessions)
: m_Log(logging::Get("sessions"))
, m_StatusService(StatusService)
, m_StatsService(StatsService)
, m_Sessions(Sessions)
{
	Initialize();
}

// Unregister the handlers registered in Initialize(), in reverse order.
HttpSessionsService::~HttpSessionsService()
{
	m_StatsService.UnregisterHandler("sessions", *this);
	m_StatusService.UnregisterHandler("sessions", *this);
}
+
+const char*
+HttpSessionsService::BaseUri() const
+{
+ return "/sessions/";
+}
+
+void
+HttpSessionsService::HandleRequest(HttpServerRequest& Request)
+{
+ metrics::OperationTiming::Scope $(m_HttpRequests);
+
+ if (m_Router.HandleRequest(Request) == false)
+ {
+ ZEN_WARN("No route found for {0}", Request.RelativeUri());
+ return Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv);
+ }
+}
+
// Build the compact-binary stats object published through the stats service.
//
// Field emission order is preserved deliberately (it determines the layout
// of the serialized object). Each counter is an independent atomic read, so
// the snapshot is not transactionally consistent across fields.
CbObject
HttpSessionsService::CollectStats()
{
	ZEN_TRACE_CPU("SessionsService::Stats");
	CbObjectWriter Cbo;

	// Timing snapshot for all HTTP traffic handled by this service
	EmitSnapshot("requests", m_HttpRequests, Cbo);

	Cbo.BeginObject("sessions");
	{
		Cbo << "readcount" << m_SessionsStats.SessionReadCount;
		Cbo << "writecount" << m_SessionsStats.SessionWriteCount;
		Cbo << "deletecount" << m_SessionsStats.SessionDeleteCount;
		Cbo << "listcount" << m_SessionsStats.SessionListCount;
		Cbo << "requestcount" << m_SessionsStats.RequestCount;
		Cbo << "badrequestcount" << m_SessionsStats.BadRequestCount;
		// Live session count comes from the backing service, not a counter
		Cbo << "count" << m_Sessions.GetSessionCount();
	}
	Cbo.EndObject();

	return Cbo.Save();
}
+
+void
+HttpSessionsService::HandleStatsRequest(HttpServerRequest& HttpReq)
+{
+ HttpReq.WriteResponse(HttpResponseCode::OK, CollectStats());
+}
+
+void
+HttpSessionsService::HandleStatusRequest(HttpServerRequest& Request)
+{
+ ZEN_TRACE_CPU("HttpSessionsService::Status");
+ CbObjectWriter Cbo;
+ Cbo << "ok" << true;
+ Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+}
+
+void
+HttpSessionsService::Initialize()
+{
+ using namespace std::literals;
+
+ ZEN_INFO("Initializing Sessions Service");
+
+ static constexpr AsciiSet ValidHexCharactersSet{"0123456789abcdefABCDEF"};
+
+ m_Router.AddMatcher("session_id", [](std::string_view Str) -> bool {
+ return Str.length() == Oid::StringLength && AsciiSet::HasOnly(Str, ValidHexCharactersSet);
+ });
+
+ m_Router.RegisterRoute(
+ "list",
+ [this](HttpRouterRequest& Req) { ListSessionsRequest(Req); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{session_id}",
+ [this](HttpRouterRequest& Req) { SessionRequest(Req); },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete);
+
+ m_Router.RegisterRoute(
+ "",
+ [this](HttpRouterRequest& Req) { ListSessionsRequest(Req); },
+ HttpVerb::kGet);
+
+ m_StatsService.RegisterHandler("sessions", *this);
+ m_StatusService.RegisterHandler("sessions", *this);
+}
+
+static void
+WriteSessionInfo(CbWriter& Writer, const SessionsService::SessionInfo& Info)
+{
+ Writer << "id" << Info.Id;
+ if (!Info.AppName.empty())
+ {
+ Writer << "appname" << Info.AppName;
+ }
+ if (Info.JobId != Oid::Zero)
+ {
+ Writer << "jobid" << Info.JobId;
+ }
+ Writer << "created_at" << Info.CreatedAt;
+ Writer << "updated_at" << Info.UpdatedAt;
+
+ if (Info.Metadata.GetSize() > 0)
+ {
+ Writer.BeginObject("metadata");
+ for (const CbField& Field : Info.Metadata)
+ {
+ Writer.AddField(Field);
+ }
+ Writer.EndObject();
+ }
+}
+
+void
+HttpSessionsService::ListSessionsRequest(HttpRouterRequest& Req)
+{
+ HttpServerRequest& ServerRequest = Req.ServerRequest();
+
+ m_SessionsStats.SessionListCount++;
+ m_SessionsStats.RequestCount++;
+
+ std::vector<Ref<SessionsService::Session>> Sessions = m_Sessions.GetSessions();
+
+ CbObjectWriter Response;
+ Response.BeginArray("sessions");
+ for (const Ref<SessionsService::Session>& Session : Sessions)
+ {
+ Response.BeginObject();
+ {
+ WriteSessionInfo(Response, Session->Info());
+ }
+ Response.EndObject();
+ }
+ Response.EndArray();
+
+ return ServerRequest.WriteResponse(HttpResponseCode::OK, Response.Save());
+}
+
+void
+HttpSessionsService::SessionRequest(HttpRouterRequest& Req)
+{
+ HttpServerRequest& ServerRequest = Req.ServerRequest();
+
+ const Oid SessionId = Oid::TryFromHexString(Req.GetCapture(1));
+ if (SessionId == Oid::Zero)
+ {
+ m_SessionsStats.BadRequestCount++;
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ fmt::format("Invalid session id '{}'", Req.GetCapture(1)));
+ }
+
+ m_SessionsStats.RequestCount++;
+
+ switch (ServerRequest.RequestVerb())
+ {
+ case HttpVerb::kPost:
+ case HttpVerb::kPut:
+ {
+ IoBuffer Payload = ServerRequest.ReadPayload();
+ CbObject RequestObject;
+
+ if (Payload.GetSize() > 0)
+ {
+ if (CbValidateError ValidationResult = ValidateCompactBinary(Payload.GetView(), CbValidateMode::All);
+ ValidationResult != CbValidateError::None)
+ {
+ m_SessionsStats.BadRequestCount++;
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ fmt::format("Invalid payload: {}", zen::ToString(ValidationResult)));
+ }
+ RequestObject = LoadCompactBinaryObject(Payload);
+ }
+
+ if (ServerRequest.RequestVerb() == HttpVerb::kPost)
+ {
+ std::string AppName(RequestObject["appname"sv].AsString());
+ Oid JobId = RequestObject["jobid"sv].AsObjectId();
+ CbObjectView MetadataView = RequestObject["metadata"sv].AsObjectView();
+
+ m_SessionsStats.SessionWriteCount++;
+ if (m_Sessions.RegisterSession(SessionId, std::move(AppName), JobId, MetadataView))
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::Created, HttpContentType::kText, fmt::format("{}", SessionId));
+ }
+ else
+ {
+ // Already exists - try update instead
+ if (m_Sessions.UpdateSession(SessionId, MetadataView))
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, fmt::format("{}", SessionId));
+ }
+ return ServerRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+ }
+ else
+ {
+ // PUT - update only
+ m_SessionsStats.SessionWriteCount++;
+ if (m_Sessions.UpdateSession(SessionId, RequestObject["metadata"sv].AsObjectView()))
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, fmt::format("{}", SessionId));
+ }
+ return ServerRequest.WriteResponse(HttpResponseCode::NotFound,
+ HttpContentType::kText,
+ fmt::format("Session '{}' not found", SessionId));
+ }
+ }
+ case HttpVerb::kGet:
+ {
+ m_SessionsStats.SessionReadCount++;
+ Ref<SessionsService::Session> Session = m_Sessions.GetSession(SessionId);
+ if (Session)
+ {
+ CbObjectWriter Response;
+ WriteSessionInfo(Response, Session->Info());
+ return ServerRequest.WriteResponse(HttpResponseCode::OK, Response.Save());
+ }
+ return ServerRequest.WriteResponse(HttpResponseCode::NotFound);
+ }
+ case HttpVerb::kDelete:
+ {
+ m_SessionsStats.SessionDeleteCount++;
+ if (m_Sessions.RemoveSession(SessionId))
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::OK);
+ }
+ return ServerRequest.WriteResponse(HttpResponseCode::NotFound);
+ }
+ }
+}
+
+} // namespace zen
diff --git a/src/zenserver/sessions/httpsessions.h b/src/zenserver/sessions/httpsessions.h
new file mode 100644
index 000000000..e07f3b59b
--- /dev/null
+++ b/src/zenserver/sessions/httpsessions.h
@@ -0,0 +1,55 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+#include <zenhttp/httpstats.h>
+#include <zenhttp/httpstatus.h>
+#include <zentelemetry/stats.h>
+
+namespace zen {
+
+class SessionsService;
+
// HTTP front-end for SessionsService: serves the /sessions/ routes and
// plugs into the shared stats and status services.
class HttpSessionsService final : public HttpService, public IHttpStatusProvider, public IHttpStatsProvider
{
public:
	HttpSessionsService(HttpStatusService& StatusService, HttpStatsService& StatsService, SessionsService& Sessions);
	virtual ~HttpSessionsService();

	// HttpService
	virtual const char* BaseUri() const override;
	virtual void		HandleRequest(HttpServerRequest& Request) override;

	// IHttpStatsProvider / IHttpStatusProvider
	virtual CbObject CollectStats() override;
	virtual void	 HandleStatsRequest(HttpServerRequest& Request) override;
	virtual void	 HandleStatusRequest(HttpServerRequest& Request) override;

private:
	// Monotonic request counters, bumped from request-handler threads;
	// atomics since requests may be served concurrently.
	struct SessionsStats
	{
		std::atomic_uint64_t SessionReadCount{};
		std::atomic_uint64_t SessionWriteCount{};
		std::atomic_uint64_t SessionDeleteCount{};
		std::atomic_uint64_t SessionListCount{};
		std::atomic_uint64_t RequestCount{};
		std::atomic_uint64_t BadRequestCount{};
	};

	inline LoggerRef Log() { return m_Log; }

	LoggerRef m_Log;

	// Builds the route table and registers with the stats/status services.
	void Initialize();

	// Route handlers
	void ListSessionsRequest(HttpRouterRequest& Req);
	void SessionRequest(HttpRouterRequest& Req);

	HttpStatusService&		 m_StatusService;
	HttpStatsService&		 m_StatsService;
	HttpRequestRouter		 m_Router;
	SessionsService&		 m_Sessions;
	SessionsStats			 m_SessionsStats;
	metrics::OperationTiming m_HttpRequests;
};
+
+} // namespace zen
diff --git a/src/zenserver/sessions/sessions.cpp b/src/zenserver/sessions/sessions.cpp
new file mode 100644
index 000000000..f73aa40ff
--- /dev/null
+++ b/src/zenserver/sessions/sessions.cpp
@@ -0,0 +1,150 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "sessions.h"
+
+#include <zencore/basicfile.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+
+namespace zen {
+using namespace std::literals;
+
// Ref-counted handle to a single session's log file, opened for writing
// on construction.
//
// NOTE(review): the constructor does not check whether Open() succeeded;
// a failure leaves m_LogFile closed. Confirm whether this best-effort
// behavior is intended.
class SessionLog : public TRefCounted<SessionLog>
{
public:
	SessionLog(std::filesystem::path LogFilePath) { m_LogFile.Open(LogFilePath, BasicFile::Mode::kWrite); }

private:
	BasicFile m_LogFile;
};
+
+class SessionLogStore
+{
+public:
+ SessionLogStore(std::filesystem::path StoragePath) : m_StoragePath(std::move(StoragePath)) {}
+
+ ~SessionLogStore() = default;
+
+ Ref<SessionLog> GetLogForSession(const Oid& SessionId)
+ {
+ // For now, just return a new log for each session. We can implement actual log storage and retrieval later.
+ return Ref(new SessionLog(m_StoragePath / (SessionId.ToString() + ".log")));
+ }
+
+ Ref<SessionLog> CreateLogForSession(const Oid& SessionId)
+ {
+ // For now, just return a new log for each session. We can implement actual log storage and retrieval later.
+ return Ref(new SessionLog(m_StoragePath / (SessionId.ToString() + ".log")));
+ }
+
+private:
+ std::filesystem::path m_StoragePath;
+};
+
// Session snapshot-copies the supplied info; see SessionsService::Session
// in sessions.h for the mutation path (UpdateMetadata).
SessionsService::Session::Session(const SessionInfo& Info) : m_Info(Info)
{
}
// Defined out-of-line - presumably so the forward-declared SessionLog held
// via Ref<SessionLog> is a complete type at destruction; confirm.
SessionsService::Session::~Session() = default;

//////////////////////////////////////////////////////////////////////////

SessionsService::SessionsService() : m_Log(logging::Get("sessions"))
{
}

// Defaulted out-of-line: m_SessionLogs is a unique_ptr to SessionLogStore,
// which is only defined in this translation unit.
SessionsService::~SessionsService() = default;
+
+bool
+SessionsService::RegisterSession(const Oid& SessionId, std::string AppName, const Oid& JobId, CbObjectView Metadata)
+{
+ RwLock::ExclusiveLockScope Lock(m_Lock);
+
+ if (m_Sessions.contains(SessionId))
+ {
+ return false;
+ }
+
+ const DateTime Now = DateTime::Now();
+ m_Sessions.emplace(SessionId,
+ Ref(new Session(SessionInfo{.Id = SessionId,
+ .AppName = std::move(AppName),
+ .JobId = JobId,
+ .Metadata = CbObject::Clone(Metadata),
+ .CreatedAt = Now,
+ .UpdatedAt = Now})));
+
+ ZEN_INFO("Session {} registered (AppName: {}, JobId: {})", SessionId, AppName, JobId);
+ return true;
+}
+
+bool
+SessionsService::UpdateSession(const Oid& SessionId, CbObjectView Metadata)
+{
+ RwLock::ExclusiveLockScope Lock(m_Lock);
+
+ auto It = m_Sessions.find(SessionId);
+ if (It == m_Sessions.end())
+ {
+ return false;
+ }
+
+ It.value()->UpdateMetadata(Metadata);
+
+ const SessionInfo& Info = It.value()->Info();
+ ZEN_DEBUG("Session {} updated (AppName: {}, JobId: {})", SessionId, Info.AppName, Info.JobId);
+ return true;
+}
+
+Ref<SessionsService::Session>
+SessionsService::GetSession(const Oid& SessionId) const
+{
+ RwLock::SharedLockScope Lock(m_Lock);
+
+ auto It = m_Sessions.find(SessionId);
+ if (It == m_Sessions.end())
+ {
+ return {};
+ }
+
+ return It->second;
+}
+
+std::vector<Ref<SessionsService::Session>>
+SessionsService::GetSessions() const
+{
+ RwLock::SharedLockScope Lock(m_Lock);
+
+ std::vector<Ref<Session>> Result;
+ Result.reserve(m_Sessions.size());
+ for (const auto& [Id, SessionRef] : m_Sessions)
+ {
+ Result.push_back(SessionRef);
+ }
+ return Result;
+}
+
+bool
+SessionsService::RemoveSession(const Oid& SessionId)
+{
+ RwLock::ExclusiveLockScope Lock(m_Lock);
+
+ auto It = m_Sessions.find(SessionId);
+ if (It == m_Sessions.end())
+ {
+ return false;
+ }
+
+ ZEN_INFO("Session {} removed (AppName: {}, JobId: {})", SessionId, It.value()->Info().AppName, It.value()->Info().JobId);
+
+ m_Sessions.erase(It);
+ return true;
+}
+
+uint64_t
+SessionsService::GetSessionCount() const
+{
+ RwLock::SharedLockScope Lock(m_Lock);
+ return m_Sessions.size();
+}
+
+} // namespace zen
diff --git a/src/zenserver/sessions/sessions.h b/src/zenserver/sessions/sessions.h
new file mode 100644
index 000000000..db9704430
--- /dev/null
+++ b/src/zenserver/sessions/sessions.h
@@ -0,0 +1,83 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/logbase.h>
+#include <zencore/thread.h>
+#include <zencore/uid.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <optional>
+#include <string>
+#include <vector>
+
+namespace zen {
+
+class SessionLogStore;
+class SessionLog;
+
/** Session tracker
 *
 * Acts as a log and session info concentrator when dealing with multiple
 * servers and external processes acting as a group.
 *
 * All public members are thread-safe: the session map is guarded by m_Lock
 * (shared for reads, exclusive for writes).
 */

class SessionsService
{
public:
	// Snapshot of everything known about one session.
	struct SessionInfo
	{
		Oid			Id;			// session identifier
		std::string AppName;	// optional; "" when unset (omitted from serialized output)
		Oid			JobId;		// optional; Oid::Zero when unset
		CbObject	Metadata;	// opaque caller-supplied metadata object
		DateTime	CreatedAt;
		DateTime	UpdatedAt;	// bumped by UpdateMetadata()
	};

	// Ref-counted session record, shared between the service map and any
	// callers holding a Ref obtained via GetSession/GetSessions.
	class Session : public TRefCounted<Session>
	{
	public:
		Session(const SessionInfo& Info);
		~Session();

		// Non-movable (and implicitly non-copyable): identity is shared
		// through Ref counting only.
		Session(Session&&)			  = delete;
		Session& operator=(Session&&) = delete;

		const SessionInfo& Info() const { return m_Info; }
		void			   UpdateMetadata(CbObjectView Metadata)
		{
			// Should this be additive rather than replacing the whole thing? We'll see.
			m_Info.Metadata	 = CbObject::Clone(Metadata);
			m_Info.UpdatedAt = DateTime::Now();
		}

	private:
		SessionInfo		m_Info;
		Ref<SessionLog> m_Log;	// NOTE(review): never assigned in visible code - confirm intent
	};

	SessionsService();
	~SessionsService();

	// Insert a new session; false if the id is already registered.
	bool RegisterSession(const Oid& SessionId, std::string AppName, const Oid& JobId, CbObjectView Metadata);
	// Replace an existing session's metadata; false if the id is unknown.
	bool UpdateSession(const Oid& SessionId, CbObjectView Metadata);
	// Null Ref when the id is unknown.
	Ref<Session>			  GetSession(const Oid& SessionId) const;
	// Snapshot of all live sessions.
	std::vector<Ref<Session>> GetSessions() const;
	bool					  RemoveSession(const Oid& SessionId);
	uint64_t				  GetSessionCount() const;

private:
	LoggerRef& Log() { return m_Log; }

	LoggerRef						 m_Log;
	mutable RwLock					 m_Lock;	 // guards m_Sessions
	tsl::robin_map<Oid, Ref<Session>, Oid::Hasher> m_Sessions;
	std::unique_ptr<SessionLogStore> m_SessionLogs;
};
+
+} // namespace zen
diff --git a/src/zenserver/storage/admin/admin.cpp b/src/zenserver/storage/admin/admin.cpp
index 19155e02b..c9f999c69 100644
--- a/src/zenserver/storage/admin/admin.cpp
+++ b/src/zenserver/storage/admin/admin.cpp
@@ -716,7 +716,7 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler,
"logs",
[this](HttpRouterRequest& Req) {
CbObjectWriter Obj;
- auto LogLevel = logging::level::ToStringView(logging::GetLogLevel());
+ auto LogLevel = logging::ToStringView(logging::GetLogLevel());
Obj.AddString("loglevel", std::string_view(LogLevel.data(), LogLevel.size()));
Obj.AddString("Logfile", PathToUtf8(m_LogPaths.AbsLogPath));
Obj.BeginObject("cache");
@@ -767,8 +767,8 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler,
}
if (std::string Param(Params.GetValue("loglevel")); Param.empty() == false)
{
- logging::level::LogLevel NewLevel = logging::level::ParseLogLevelString(Param);
- std::string_view LogLevel = logging::level::ToStringView(NewLevel);
+ logging::LogLevel NewLevel = logging::ParseLogLevelString(Param);
+ std::string_view LogLevel = logging::ToStringView(NewLevel);
if (LogLevel != Param)
{
return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest,
diff --git a/src/zenserver/storage/buildstore/httpbuildstore.cpp b/src/zenserver/storage/buildstore/httpbuildstore.cpp
index f5ba30616..de9589078 100644
--- a/src/zenserver/storage/buildstore/httpbuildstore.cpp
+++ b/src/zenserver/storage/buildstore/httpbuildstore.cpp
@@ -71,7 +71,7 @@ HttpBuildStoreService::Initialize()
m_Router.RegisterRoute(
"{namespace}/{bucket}/{buildid}/blobs/{hash}",
[this](HttpRouterRequest& Req) { GetBlobRequest(Req); },
- HttpVerb::kGet);
+ HttpVerb::kGet | HttpVerb::kPost);
m_Router.RegisterRoute(
"{namespace}/{bucket}/{buildid}/blobs/putBlobMetadata",
@@ -161,14 +161,57 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req)
HttpContentType::kText,
fmt::format("Invalid blob hash '{}'", Hash));
}
- zen::HttpRanges Ranges;
- bool HasRange = ServerRequest.TryGetRanges(Ranges);
- if (Ranges.size() > 1)
+
+ std::vector<std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs;
+ if (ServerRequest.RequestVerb() == HttpVerb::kPost)
{
- // Only a single range is supported
- return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
- HttpContentType::kText,
- "Multiple ranges in blob request is not supported");
+ CbObject RangePayload = ServerRequest.ReadPayloadObject();
+ if (RangePayload)
+ {
+ CbArrayView RangesArray = RangePayload["ranges"sv].AsArrayView();
+ OffsetAndLengthPairs.reserve(RangesArray.Num());
+ for (CbFieldView FieldView : RangesArray)
+ {
+ CbObjectView RangeView = FieldView.AsObjectView();
+ uint64_t RangeOffset = RangeView["offset"sv].AsUInt64();
+ uint64_t RangeLength = RangeView["length"sv].AsUInt64();
+ OffsetAndLengthPairs.push_back(std::make_pair(RangeOffset, RangeLength));
+ }
+ if (OffsetAndLengthPairs.size() > MaxRangeCountPerRequestSupported)
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ fmt::format("Number of ranges ({}) for blob request exceeds maximum range count {}",
+ OffsetAndLengthPairs.size(),
+ MaxRangeCountPerRequestSupported));
+ }
+ }
+ if (OffsetAndLengthPairs.empty())
+ {
+ m_BuildStoreStats.BadRequestCount++;
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ "Fetching blob without ranges must be done with the GET verb");
+ }
+ }
+ else
+ {
+ HttpRanges Ranges;
+ bool HasRange = ServerRequest.TryGetRanges(Ranges);
+ if (HasRange)
+ {
+ if (Ranges.size() > 1)
+ {
+ // Only a single http range is supported, we have limited support for http multirange responses
+ m_BuildStoreStats.BadRequestCount++;
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ fmt::format("Multiple ranges in blob request is only supported for {} accept type",
+ ToString(HttpContentType::kCbPackage)));
+ }
+ const HttpRange& FirstRange = Ranges.front();
+ OffsetAndLengthPairs.push_back(std::make_pair<uint64_t, uint64_t>(FirstRange.Start, FirstRange.End - FirstRange.Start + 1));
+ }
}
m_BuildStoreStats.BlobReadCount++;
@@ -179,24 +222,79 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req)
HttpContentType::kText,
fmt::format("Blob with hash '{}' could not be found", Hash));
}
- // ZEN_INFO("Fetched blob {}. Size: {}", BlobHash, Blob.GetSize());
m_BuildStoreStats.BlobHitCount++;
- if (HasRange)
+
+ if (OffsetAndLengthPairs.empty())
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::OK, Blob.GetContentType(), Blob);
+ }
+
+ if (ServerRequest.AcceptContentType() == HttpContentType::kCbPackage)
{
- const HttpRange& Range = Ranges.front();
- const uint64_t BlobSize = Blob.GetSize();
- const uint64_t MaxBlobSize = Range.Start < BlobSize ? Range.Start - BlobSize : 0;
- const uint64_t RangeSize = Min(Range.End - Range.Start + 1, MaxBlobSize);
- if (Range.Start + RangeSize > BlobSize)
+ const uint64_t BlobSize = Blob.GetSize();
+
+ CbPackage ResponsePackage;
+ std::vector<IoBuffer> RangeBuffers;
+ CbObjectWriter Writer;
+ Writer.BeginArray("ranges"sv);
+ for (const std::pair<uint64_t, uint64_t>& Range : OffsetAndLengthPairs)
{
- return ServerRequest.WriteResponse(HttpResponseCode::NoContent);
+ const uint64_t MaxBlobSize = Range.first < BlobSize ? BlobSize - Range.first : 0;
+ const uint64_t RangeSize = Min(Range.second, MaxBlobSize);
+ Writer.BeginObject();
+ {
+ if (Range.first + RangeSize <= BlobSize)
+ {
+ RangeBuffers.push_back(IoBuffer(Blob, Range.first, RangeSize));
+ Writer.AddInteger("offset"sv, Range.first);
+ Writer.AddInteger("length"sv, RangeSize);
+ }
+ else
+ {
+ Writer.AddInteger("offset"sv, Range.first);
+ Writer.AddInteger("length"sv, 0);
+ }
+ }
+ Writer.EndObject();
}
- Blob = IoBuffer(Blob, Range.Start, RangeSize);
- return ServerRequest.WriteResponse(HttpResponseCode::OK, ZenContentType::kBinary, Blob);
+ Writer.EndArray();
+
+ CompositeBuffer Ranges(RangeBuffers);
+ CbAttachment PayloadAttachment(std::move(Ranges), BlobHash);
+ Writer.AddAttachment("payload", PayloadAttachment);
+
+ CbObject HeaderObject = Writer.Save();
+
+ ResponsePackage.AddAttachment(PayloadAttachment);
+ ResponsePackage.SetObject(HeaderObject);
+
+ CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(ResponsePackage);
+ uint64_t ResponseSize = RpcResponseBuffer.GetSize();
+ ZEN_UNUSED(ResponseSize);
+ return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, RpcResponseBuffer);
}
else
{
- return ServerRequest.WriteResponse(HttpResponseCode::OK, Blob.GetContentType(), Blob);
+ if (OffsetAndLengthPairs.size() != 1)
+ {
+ // Only a single http range is supported, we have limited support for http multirange responses
+ m_BuildStoreStats.BadRequestCount++;
+ return ServerRequest.WriteResponse(
+ HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ fmt::format("Multiple ranges in blob request is only supported for {} accept type", ToString(HttpContentType::kCbPackage)));
+ }
+
+ const std::pair<uint64_t, uint64_t>& OffsetAndLength = OffsetAndLengthPairs.front();
+ const uint64_t BlobSize = Blob.GetSize();
+ const uint64_t MaxBlobSize = OffsetAndLength.first < BlobSize ? BlobSize - OffsetAndLength.first : 0;
+ const uint64_t RangeSize = Min(OffsetAndLength.second, MaxBlobSize);
+ if (OffsetAndLength.first + RangeSize > BlobSize)
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::NoContent);
+ }
+ Blob = IoBuffer(Blob, OffsetAndLength.first, RangeSize);
+ return ServerRequest.WriteResponse(HttpResponseCode::OK, ZenContentType::kBinary, Blob);
}
}
@@ -507,8 +605,8 @@ HttpBuildStoreService::BlobsExistsRequest(HttpRouterRequest& Req)
return ServerRequest.WriteResponse(HttpResponseCode::OK, ResponseObject);
}
-void
-HttpBuildStoreService::HandleStatsRequest(HttpServerRequest& Request)
+CbObject
+HttpBuildStoreService::CollectStats()
{
ZEN_TRACE_CPU("HttpBuildStoreService::Stats");
@@ -562,7 +660,13 @@ HttpBuildStoreService::HandleStatsRequest(HttpServerRequest& Request)
}
Cbo.EndObject();
- return Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ return Cbo.Save();
+}
+
+void
+HttpBuildStoreService::HandleStatsRequest(HttpServerRequest& Request)
+{
+ Request.WriteResponse(HttpResponseCode::OK, CollectStats());
}
void
@@ -571,6 +675,11 @@ HttpBuildStoreService::HandleStatusRequest(HttpServerRequest& Request)
ZEN_TRACE_CPU("HttpBuildStoreService::Status");
CbObjectWriter Cbo;
Cbo << "ok" << true;
+ Cbo.BeginObject("capabilities");
+ {
+ Cbo << "maxrangecountperrequest" << MaxRangeCountPerRequestSupported;
+ }
+ Cbo.EndObject(); // capabilities
Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
}
diff --git a/src/zenserver/storage/buildstore/httpbuildstore.h b/src/zenserver/storage/buildstore/httpbuildstore.h
index e10986411..2a09b71cf 100644
--- a/src/zenserver/storage/buildstore/httpbuildstore.h
+++ b/src/zenserver/storage/buildstore/httpbuildstore.h
@@ -22,8 +22,9 @@ public:
virtual const char* BaseUri() const override;
virtual void HandleRequest(zen::HttpServerRequest& Request) override;
- virtual void HandleStatsRequest(HttpServerRequest& Request) override;
- virtual void HandleStatusRequest(HttpServerRequest& Request) override;
+ virtual CbObject CollectStats() override;
+ virtual void HandleStatsRequest(HttpServerRequest& Request) override;
+ virtual void HandleStatusRequest(HttpServerRequest& Request) override;
private:
struct BuildStoreStats
@@ -45,6 +46,8 @@ private:
inline LoggerRef Log() { return m_Log; }
+ static constexpr uint32_t MaxRangeCountPerRequestSupported = 256u;
+
LoggerRef m_Log;
void PutBlobRequest(HttpRouterRequest& Req);
diff --git a/src/zenserver/storage/cache/httpstructuredcache.cpp b/src/zenserver/storage/cache/httpstructuredcache.cpp
index 72f29d14e..06b8f6c27 100644
--- a/src/zenserver/storage/cache/httpstructuredcache.cpp
+++ b/src/zenserver/storage/cache/httpstructuredcache.cpp
@@ -654,7 +654,7 @@ HttpStructuredCacheService::HandleCacheNamespaceRequest(HttpServerRequest& Reque
auto NewEnd = std::unique(AllAttachments.begin(), AllAttachments.end());
AllAttachments.erase(NewEnd, AllAttachments.end());
- uint64_t AttachmentsSize = 0;
+ std::atomic<uint64_t> AttachmentsSize = 0;
m_CidStore.IterateChunks(
AllAttachments,
@@ -746,7 +746,7 @@ HttpStructuredCacheService::HandleCacheBucketRequest(HttpServerRequest& Request,
ResponseWriter << "Size" << ValuesSize;
ResponseWriter << "AttachmentCount" << ContentStats.Attachments.size();
- uint64_t AttachmentsSize = 0;
+ std::atomic<uint64_t> AttachmentsSize = 0;
WorkerThreadPool& WorkerPool = GetMediumWorkerPool(EWorkloadType::Background);
@@ -1827,8 +1827,8 @@ HttpStructuredCacheService::HandleRpcRequest(HttpServerRequest& Request, std::st
}
}
-void
-HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request)
+CbObject
+HttpStructuredCacheService::CollectStats()
{
ZEN_MEMSCOPE(GetCacheHttpTag());
@@ -1858,13 +1858,132 @@ HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request)
const CidStoreSize CidSize = m_CidStore.TotalSize();
const CacheStoreSize CacheSize = m_CacheStore.TotalSize();
+ Cbo.BeginObject("cache");
+ {
+ Cbo << "badrequestcount" << BadRequestCount;
+ Cbo.BeginObject("rpc");
+ Cbo << "count" << RpcRequests;
+ Cbo << "ops" << RpcRecordBatchRequests + RpcValueBatchRequests + RpcChunkBatchRequests;
+ Cbo.BeginObject("records");
+ Cbo << "count" << RpcRecordRequests;
+ Cbo << "ops" << RpcRecordBatchRequests;
+ Cbo.EndObject();
+ Cbo.BeginObject("values");
+ Cbo << "count" << RpcValueRequests;
+ Cbo << "ops" << RpcValueBatchRequests;
+ Cbo.EndObject();
+ Cbo.BeginObject("chunks");
+ Cbo << "count" << RpcChunkRequests;
+ Cbo << "ops" << RpcChunkBatchRequests;
+ Cbo.EndObject();
+ Cbo.EndObject();
+
+ Cbo.BeginObject("size");
+ {
+ Cbo << "disk" << CacheSize.DiskSize;
+ Cbo << "memory" << CacheSize.MemorySize;
+ }
+ Cbo.EndObject();
+
+ Cbo << "hits" << HitCount << "misses" << MissCount << "writes" << WriteCount;
+ Cbo << "hit_ratio" << (TotalCount > 0 ? (double(HitCount) / double(TotalCount)) : 0.0);
+
+ if (m_UpstreamCache.IsActive())
+ {
+ Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0);
+ Cbo << "upstream_hits" << m_CacheStats.UpstreamHitCount;
+ Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0);
+ Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0);
+ }
+
+ Cbo << "cidhits" << ChunkHitCount << "cidmisses" << ChunkMissCount << "cidwrites" << ChunkWriteCount;
+
+ {
+ ZenCacheStore::CacheStoreStats StoreStatsData = m_CacheStore.Stats();
+ Cbo.BeginObject("store");
+ Cbo << "hits" << StoreStatsData.HitCount << "misses" << StoreStatsData.MissCount << "writes" << StoreStatsData.WriteCount
+ << "rejected_writes" << StoreStatsData.RejectedWriteCount << "rejected_reads" << StoreStatsData.RejectedReadCount;
+ const uint64_t StoreTotal = StoreStatsData.HitCount + StoreStatsData.MissCount;
+ Cbo << "hit_ratio" << (StoreTotal > 0 ? (double(StoreStatsData.HitCount) / double(StoreTotal)) : 0.0);
+ EmitSnapshot("read", StoreStatsData.GetOps, Cbo);
+ EmitSnapshot("write", StoreStatsData.PutOps, Cbo);
+ Cbo.EndObject();
+ }
+ }
+ Cbo.EndObject();
+
+ if (m_UpstreamCache.IsActive())
+ {
+ EmitSnapshot("upstream_gets", m_UpstreamGetRequestTiming, Cbo);
+ Cbo.BeginObject("upstream");
+ {
+ m_UpstreamCache.GetStatus(Cbo);
+ }
+ Cbo.EndObject();
+ }
+
+ Cbo.BeginObject("cid");
+ {
+ Cbo.BeginObject("size");
+ {
+ Cbo << "tiny" << CidSize.TinySize;
+ Cbo << "small" << CidSize.SmallSize;
+ Cbo << "large" << CidSize.LargeSize;
+ Cbo << "total" << CidSize.TotalSize;
+ }
+ Cbo.EndObject();
+ }
+ Cbo.EndObject();
+
+ return Cbo.Save();
+}
+
+void
+HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request)
+{
+ ZEN_MEMSCOPE(GetCacheHttpTag());
+
bool ShowCidStoreStats = Request.GetQueryParams().GetValue("cidstorestats") == "true";
bool ShowCacheStoreStats = Request.GetQueryParams().GetValue("cachestorestats") == "true";
- CidStoreStats CidStoreStats = {};
+ if (!ShowCidStoreStats && !ShowCacheStoreStats)
+ {
+ Request.WriteResponse(HttpResponseCode::OK, CollectStats());
+ return;
+ }
+
+ // Full stats with optional detailed store/cid breakdowns
+
+ CbObjectWriter Cbo;
+
+ EmitSnapshot("requests", m_HttpRequests, Cbo);
+
+ const uint64_t HitCount = m_CacheStats.HitCount;
+ const uint64_t UpstreamHitCount = m_CacheStats.UpstreamHitCount;
+ const uint64_t MissCount = m_CacheStats.MissCount;
+ const uint64_t WriteCount = m_CacheStats.WriteCount;
+ const uint64_t BadRequestCount = m_CacheStats.BadRequestCount;
+ struct CidStoreStats StoreStats = m_CidStore.Stats();
+ const uint64_t ChunkHitCount = StoreStats.HitCount;
+ const uint64_t ChunkMissCount = StoreStats.MissCount;
+ const uint64_t ChunkWriteCount = StoreStats.WriteCount;
+ const uint64_t TotalCount = HitCount + MissCount;
+
+ const uint64_t RpcRequests = m_CacheStats.RpcRequests;
+ const uint64_t RpcRecordRequests = m_CacheStats.RpcRecordRequests;
+ const uint64_t RpcRecordBatchRequests = m_CacheStats.RpcRecordBatchRequests;
+ const uint64_t RpcValueRequests = m_CacheStats.RpcValueRequests;
+ const uint64_t RpcValueBatchRequests = m_CacheStats.RpcValueBatchRequests;
+ const uint64_t RpcChunkRequests = m_CacheStats.RpcChunkRequests;
+ const uint64_t RpcChunkBatchRequests = m_CacheStats.RpcChunkBatchRequests;
+
+ const CidStoreSize CidSize = m_CidStore.TotalSize();
+ const CacheStoreSize CacheSize = m_CacheStore.TotalSize();
+
+ CidStoreStats DetailedCidStoreStats = {};
if (ShowCidStoreStats)
{
- CidStoreStats = m_CidStore.Stats();
+ DetailedCidStoreStats = m_CidStore.Stats();
}
ZenCacheStore::CacheStoreStats CacheStoreStats = {};
if (ShowCacheStoreStats)
@@ -2002,8 +2121,8 @@ HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request)
}
Cbo.EndObject();
}
- Cbo.EndObject();
}
+ Cbo.EndObject();
if (m_UpstreamCache.IsActive())
{
@@ -2029,10 +2148,10 @@ HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request)
if (ShowCidStoreStats)
{
Cbo.BeginObject("store");
- Cbo << "hits" << CidStoreStats.HitCount << "misses" << CidStoreStats.MissCount << "writes" << CidStoreStats.WriteCount;
- EmitSnapshot("read", CidStoreStats.FindChunkOps, Cbo);
- EmitSnapshot("write", CidStoreStats.AddChunkOps, Cbo);
- // EmitSnapshot("exists", CidStoreStats.ContainChunkOps, Cbo);
+ Cbo << "hits" << DetailedCidStoreStats.HitCount << "misses" << DetailedCidStoreStats.MissCount << "writes"
+ << DetailedCidStoreStats.WriteCount;
+ EmitSnapshot("read", DetailedCidStoreStats.FindChunkOps, Cbo);
+ EmitSnapshot("write", DetailedCidStoreStats.AddChunkOps, Cbo);
Cbo.EndObject();
}
}
diff --git a/src/zenserver/storage/cache/httpstructuredcache.h b/src/zenserver/storage/cache/httpstructuredcache.h
index 5a795c215..d462415d4 100644
--- a/src/zenserver/storage/cache/httpstructuredcache.h
+++ b/src/zenserver/storage/cache/httpstructuredcache.h
@@ -102,11 +102,12 @@ private:
void HandleRpcRequest(HttpServerRequest& Request, std::string_view UriNamespace);
void HandleDetailsRequest(HttpServerRequest& Request);
- void HandleCacheRequest(HttpServerRequest& Request);
- void HandleCacheNamespaceRequest(HttpServerRequest& Request, std::string_view Namespace);
- void HandleCacheBucketRequest(HttpServerRequest& Request, std::string_view Namespace, std::string_view Bucket);
- virtual void HandleStatsRequest(HttpServerRequest& Request) override;
- virtual void HandleStatusRequest(HttpServerRequest& Request) override;
+ void HandleCacheRequest(HttpServerRequest& Request);
+ void HandleCacheNamespaceRequest(HttpServerRequest& Request, std::string_view Namespace);
+ void HandleCacheBucketRequest(HttpServerRequest& Request, std::string_view Namespace, std::string_view Bucket);
+ virtual CbObject CollectStats() override;
+ virtual void HandleStatsRequest(HttpServerRequest& Request) override;
+ virtual void HandleStatusRequest(HttpServerRequest& Request) override;
bool AreDiskWritesAllowed() const;
diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp
index fe32fa15b..836d84292 100644
--- a/src/zenserver/storage/projectstore/httpprojectstore.cpp
+++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp
@@ -13,7 +13,12 @@
#include <zencore/scopeguard.h>
#include <zencore/stream.h>
#include <zencore/trace.h>
+#include <zenhttp/httpclientauth.h>
#include <zenhttp/packageformat.h>
+#include <zenremotestore/builds/buildstoragecache.h>
+#include <zenremotestore/builds/buildstorageutil.h>
+#include <zenremotestore/jupiter/jupiterhost.h>
+#include <zenremotestore/operationlogoutput.h>
#include <zenremotestore/projectstore/buildsremoteprojectstore.h>
#include <zenremotestore/projectstore/fileremoteprojectstore.h>
#include <zenremotestore/projectstore/jupiterremoteprojectstore.h>
@@ -244,6 +249,22 @@ namespace {
{
std::shared_ptr<RemoteProjectStore> Store;
std::string Description;
+ double LatencySec = -1.0;
+ uint64_t MaxRangeCountPerRequest = 1;
+
+ struct Cache
+ {
+ std::unique_ptr<HttpClient> Http;
+ std::unique_ptr<BuildStorageCache> Cache;
+ Oid BuildsId = Oid::Zero;
+ std::string Description;
+ double LatencySec = -1.0;
+ uint64_t MaxRangeCountPerRequest = 1;
+ BuildStorageCache::Statistics Stats;
+ bool Populate = false;
+ };
+
+ std::unique_ptr<Cache> OptionalCache;
};
CreateRemoteStoreResult CreateRemoteStore(LoggerRef InLog,
@@ -260,7 +281,7 @@ namespace {
using namespace std::literals;
- std::shared_ptr<RemoteProjectStore> RemoteStore;
+ CreateRemoteStoreResult Result;
if (CbObjectView File = Params["file"sv].AsObjectView(); File)
{
@@ -285,7 +306,9 @@ namespace {
std::string(OptionalBaseName),
ForceDisableBlocks,
ForceEnableTempBlocks};
- RemoteStore = CreateFileRemoteStore(Log(), Options);
+ Result.Store = CreateFileRemoteStore(Log(), Options);
+ Result.LatencySec = 0.5 / 1000.0; // 0.5 ms
+ Result.MaxRangeCountPerRequest = 1024u;
}
if (CbObjectView Cloud = Params["cloud"sv].AsObjectView(); Cloud)
@@ -363,21 +386,32 @@ namespace {
bool ForceDisableTempBlocks = Cloud["disabletempblocks"sv].AsBool(false);
bool AssumeHttp2 = Cloud["assumehttp2"sv].AsBool(false);
- JupiterRemoteStoreOptions Options = {
- RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = MaxChunkEmbedSize},
- Url,
- std::string(Namespace),
- std::string(Bucket),
- Key,
- BaseKey,
- std::string(OpenIdProvider),
- AccessToken,
- AuthManager,
- OidcExePath,
- ForceDisableBlocks,
- ForceDisableTempBlocks,
- AssumeHttp2};
- RemoteStore = CreateJupiterRemoteStore(Log(), Options, TempFilePath, /*Quiet*/ false, /*Unattended*/ false, /*Hidden*/ true);
+ if (JupiterEndpointTestResult TestResult = TestJupiterEndpoint(Url, AssumeHttp2, /*Verbose*/ false); TestResult.Success)
+ {
+ Result.LatencySec = TestResult.LatencySeconds;
+ Result.MaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest;
+
+ JupiterRemoteStoreOptions Options = {
+ RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = MaxChunkEmbedSize},
+ Url,
+ std::string(Namespace),
+ std::string(Bucket),
+ Key,
+ BaseKey,
+ std::string(OpenIdProvider),
+ AccessToken,
+ AuthManager,
+ OidcExePath,
+ ForceDisableBlocks,
+ ForceDisableTempBlocks,
+ AssumeHttp2};
+ Result.Store =
+ CreateJupiterRemoteStore(Log(), Options, TempFilePath, /*Quiet*/ false, /*Unattended*/ false, /*Hidden*/ true);
+ }
+ else
+ {
+ return {nullptr, fmt::format("Unable to connect to jupiter host '{}'", Url)};
+ }
}
if (CbObjectView Zen = Params["zen"sv].AsObjectView(); Zen)
@@ -393,12 +427,13 @@ namespace {
{
return {nullptr, "Missing oplog"};
}
+
ZenRemoteStoreOptions Options = {
RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = MaxChunkEmbedSize},
std::string(Url),
std::string(Project),
std::string(Oplog)};
- RemoteStore = CreateZenRemoteStore(Log(), Options, TempFilePath);
+ Result.Store = CreateZenRemoteStore(Log(), Options, TempFilePath);
}
if (CbObjectView Builds = Params["builds"sv].AsObjectView(); Builds)
@@ -471,11 +506,76 @@ namespace {
MemoryView MetaDataSection = Builds["metadata"sv].AsBinaryView();
IoBuffer MetaData(IoBuffer::Wrap, MetaDataSection.GetData(), MetaDataSection.GetSize());
+ auto EnsureHttps = [](const std::string& Host, std::string_view PreferredProtocol) {
+ if (!Host.empty() && Host.find("://"sv) == std::string::npos)
+ {
+ // No scheme present - prepend the caller-supplied preferred protocol ("https" or "http")
+ return fmt::format("{}://{}"sv, PreferredProtocol, Host);
+ }
+ return Host;
+ };
+
+ Host = EnsureHttps(Host, "https");
+ OverrideHost = EnsureHttps(OverrideHost, "https");
+ ZenHost = EnsureHttps(ZenHost, "http");
+
+ std::function<HttpClientAccessToken()> TokenProvider;
+ if (!OpenIdProvider.empty())
+ {
+ TokenProvider = httpclientauth::CreateFromOpenIdProvider(AuthManager, OpenIdProvider);
+ }
+ else if (!AccessToken.empty())
+ {
+ TokenProvider = httpclientauth::CreateFromStaticToken(AccessToken);
+ }
+ else if (!OidcExePath.empty())
+ {
+ if (auto TokenProviderMaybe = httpclientauth::CreateFromOidcTokenExecutable(OidcExePath,
+ Host.empty() ? OverrideHost : Host,
+ /*Quiet*/ false,
+ /*Unattended*/ false,
+ /*Hidden*/ true);
+ TokenProviderMaybe)
+ {
+ TokenProvider = TokenProviderMaybe.value();
+ }
+ }
+
+ if (!TokenProvider)
+ {
+ TokenProvider = httpclientauth::CreateFromDefaultOpenIdProvider(AuthManager);
+ }
+
+ BuildStorageResolveResult ResolveResult;
+ {
+ HttpClientSettings ClientSettings{.LogCategory = "httpbuildsclient",
+ .AccessTokenProvider = TokenProvider,
+ .AssumeHttp2 = AssumeHttp2,
+ .AllowResume = true,
+ .RetryCount = 2};
+
+ std::unique_ptr<OperationLogOutput> Output(CreateStandardLogOutput(Log()));
+
+ try
+ {
+ ResolveResult = ResolveBuildStorage(*Output,
+ ClientSettings,
+ Host,
+ OverrideHost,
+ ZenHost,
+ ZenCacheResolveMode::Discovery,
+ /*Verbose*/ false);
+ }
+ catch (const std::exception& Ex)
+ {
+ return {nullptr, fmt::format("Failed resolving storage host and cache. Reason: '{}'", Ex.what())};
+ }
+ }
+ Result.LatencySec = ResolveResult.Cloud.LatencySec;
+ Result.MaxRangeCountPerRequest = ResolveResult.Cloud.Caps.MaxRangeCountPerRequest;
+
BuildsRemoteStoreOptions Options = {
RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = MaxChunkEmbedSize},
- Host,
- OverrideHost,
- ZenHost,
std::string(Namespace),
std::string(Bucket),
BuildId,
@@ -485,25 +585,43 @@ namespace {
OidcExePath,
ForceDisableBlocks,
ForceDisableTempBlocks,
- AssumeHttp2,
- PopulateCache,
MetaData,
MaximumInMemoryDownloadSize};
- RemoteStore = CreateJupiterBuildsRemoteStore(Log(),
- Options,
- TempFilePath,
- /*Quiet*/ false,
- /*Unattended*/ false,
- /*Hidden*/ true,
- GetTinyWorkerPool(EWorkloadType::Background));
+ Result.Store = CreateJupiterBuildsRemoteStore(Log(), ResolveResult, std::move(TokenProvider), Options, TempFilePath);
+
+ if (!ResolveResult.Cache.Address.empty())
+ {
+ Result.OptionalCache = std::make_unique<CreateRemoteStoreResult::Cache>();
+
+ HttpClientSettings CacheClientSettings{.LogCategory = "httpcacheclient",
+ .ConnectTimeout = std::chrono::milliseconds{3000},
+ .Timeout = std::chrono::milliseconds{30000},
+ .AssumeHttp2 = ResolveResult.Cache.AssumeHttp2,
+ .AllowResume = true,
+ .RetryCount = 0,
+ .MaximumInMemoryDownloadSize = MaximumInMemoryDownloadSize};
+
+ Result.OptionalCache->Http = std::make_unique<HttpClient>(ResolveResult.Cache.Address, CacheClientSettings);
+ Result.OptionalCache->Cache = CreateZenBuildStorageCache(*Result.OptionalCache->Http,
+ Result.OptionalCache->Stats,
+ Namespace,
+ Bucket,
+ TempFilePath,
+ GetTinyWorkerPool(EWorkloadType::Background));
+ Result.OptionalCache->BuildsId = BuildId;
+ Result.OptionalCache->LatencySec = ResolveResult.Cache.LatencySec;
+ Result.OptionalCache->MaxRangeCountPerRequest = ResolveResult.Cache.Caps.MaxRangeCountPerRequest;
+ Result.OptionalCache->Populate = PopulateCache;
+ Result.OptionalCache->Description =
+ fmt::format("[zenserver] {} namespace {} bucket {}", ResolveResult.Cache.Address, Namespace, Bucket);
+ }
}
-
- if (!RemoteStore)
+ if (!Result.Store)
{
return {nullptr, "Unknown remote store type"};
}
- return {std::move(RemoteStore), ""};
+ return Result;
}
std::pair<HttpResponseCode, std::string> ConvertResult(const RemoteProjectStore::Result& Result)
@@ -714,8 +832,8 @@ HttpProjectService::HandleRequest(HttpServerRequest& Request)
}
}
-void
-HttpProjectService::HandleStatsRequest(HttpServerRequest& HttpReq)
+CbObject
+HttpProjectService::CollectStats()
{
ZEN_TRACE_CPU("ProjectService::Stats");
@@ -781,7 +899,13 @@ HttpProjectService::HandleStatsRequest(HttpServerRequest& HttpReq)
}
Cbo.EndObject();
- return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ return Cbo.Save();
+}
+
+void
+HttpProjectService::HandleStatsRequest(HttpServerRequest& HttpReq)
+{
+ HttpReq.WriteResponse(HttpResponseCode::OK, CollectStats());
}
void
@@ -2373,15 +2497,19 @@ HttpProjectService::HandleOplogSaveRequest(HttpRouterRequest& Req)
tsl::robin_set<IoHash, IoHash::Hasher> Attachments;
auto HasAttachment = [this](const IoHash& RawHash) { return m_CidStore.ContainsChunk(RawHash); };
- auto OnNeedBlock = [&AttachmentsLock, &Attachments](const IoHash& BlockHash, const std::vector<IoHash>&& ChunkHashes) {
+ auto OnNeedBlock = [&AttachmentsLock, &Attachments](ThinChunkBlockDescription&& ThinBlockDescription,
+ std::vector<uint32_t>&& NeededChunkIndexes) {
RwLock::ExclusiveLockScope _(AttachmentsLock);
- if (BlockHash != IoHash::Zero)
+ if (ThinBlockDescription.BlockHash != IoHash::Zero)
{
- Attachments.insert(BlockHash);
+ Attachments.insert(ThinBlockDescription.BlockHash);
}
else
{
- Attachments.insert(ChunkHashes.begin(), ChunkHashes.end());
+ for (uint32_t ChunkIndex : NeededChunkIndexes)
+ {
+ Attachments.insert(ThinBlockDescription.ChunkRawHashes[ChunkIndex]);
+ }
}
};
auto OnNeedAttachment = [&AttachmentsLock, &Attachments](const IoHash& RawHash) {
@@ -2687,36 +2815,39 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req)
bool CleanOplog = Params["clean"].AsBool(false);
bool BoostWorkerCount = Params["boostworkercount"].AsBool(false);
bool BoostWorkerMemory = Params["boostworkermemory"sv].AsBool(false);
-
- CreateRemoteStoreResult RemoteStoreResult = CreateRemoteStore(Log(),
- Params,
- m_AuthMgr,
- MaxBlockSize,
- MaxChunkEmbedSize,
- GetMaxMemoryBufferSize(MaxBlockSize, BoostWorkerMemory),
- Oplog->TempPath());
-
- if (RemoteStoreResult.Store == nullptr)
+ EPartialBlockRequestMode PartialBlockRequestMode =
+ PartialBlockRequestModeFromString(Params["partialblockrequestmode"sv].AsString("true"));
+
+ std::shared_ptr<CreateRemoteStoreResult> RemoteStoreResult =
+ std::make_shared<CreateRemoteStoreResult>(CreateRemoteStore(Log(),
+ Params,
+ m_AuthMgr,
+ MaxBlockSize,
+ MaxChunkEmbedSize,
+ GetMaxMemoryBufferSize(MaxBlockSize, BoostWorkerMemory),
+ Oplog->TempPath()));
+
+ if (RemoteStoreResult->Store == nullptr)
{
- return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, RemoteStoreResult.Description);
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, RemoteStoreResult->Description);
}
- std::shared_ptr<RemoteProjectStore> RemoteStore = std::move(RemoteStoreResult.Store);
- RemoteProjectStore::RemoteStoreInfo StoreInfo = RemoteStore->GetInfo();
JobId JobId = m_JobQueue.QueueJob(
fmt::format("Import oplog '{}/{}'", Project->Identifier, Oplog->OplogId()),
[this,
- ChunkStore = &m_CidStore,
- ActualRemoteStore = std::move(RemoteStore),
+ RemoteStoreResult = std::move(RemoteStoreResult),
Oplog,
Force,
IgnoreMissingAttachments,
CleanOplog,
+ PartialBlockRequestMode,
BoostWorkerCount](JobContext& Context) {
- Context.ReportMessage(fmt::format("Loading oplog '{}/{}' from {}",
- Oplog->GetOuterProjectIdentifier(),
- Oplog->OplogId(),
- ActualRemoteStore->GetInfo().Description));
+ Context.ReportMessage(
+ fmt::format("Loading oplog '{}/{}'\n Host: {}\n Cache: {}",
+ Oplog->GetOuterProjectIdentifier(),
+ Oplog->OplogId(),
+ RemoteStoreResult->Store->GetInfo().Description,
+ RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->Description : "<none>"));
Ref<TransferThreadWorkers> Workers = GetThreadWorkers(BoostWorkerCount, /*SingleThreaded*/ false);
@@ -2724,16 +2855,26 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req)
WorkerThreadPool& NetworkWorkerPool = Workers->GetNetworkPool();
Context.ReportMessage(fmt::format("{}", Workers->GetWorkersInfo()));
-
- RemoteProjectStore::Result Result = LoadOplog(m_CidStore,
- *ActualRemoteStore,
- *Oplog,
- NetworkWorkerPool,
- WorkerPool,
- Force,
- IgnoreMissingAttachments,
- CleanOplog,
- &Context);
+ RemoteProjectStore::Result Result = LoadOplog(LoadOplogContext{
+ .ChunkStore = m_CidStore,
+ .RemoteStore = *RemoteStoreResult->Store,
+ .OptionalCache = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->Cache.get() : nullptr,
+ .CacheBuildId = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->BuildsId : Oid::Zero,
+ .OptionalCacheStats = RemoteStoreResult->OptionalCache ? &RemoteStoreResult->OptionalCache->Stats : nullptr,
+ .Oplog = *Oplog,
+ .NetworkWorkerPool = NetworkWorkerPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = Force,
+ .IgnoreMissingAttachments = IgnoreMissingAttachments,
+ .CleanOplog = CleanOplog,
+ .PartialBlockRequestMode = PartialBlockRequestMode,
+ .PopulateCache = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->Populate : false,
+ .StoreLatencySec = RemoteStoreResult->LatencySec,
+ .StoreMaxRangeCountPerRequest = RemoteStoreResult->MaxRangeCountPerRequest,
+ .CacheLatencySec = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->LatencySec : -1.0,
+ .CacheMaxRangeCountPerRequest =
+ RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->MaxRangeCountPerRequest : 0,
+ .OptionalJobContext = &Context});
auto Response = ConvertResult(Result);
ZEN_INFO("LoadOplog: Status: {} '{}'", ToString(Response.first), Response.second);
if (!IsHttpSuccessCode(Response.first))
diff --git a/src/zenserver/storage/projectstore/httpprojectstore.h b/src/zenserver/storage/projectstore/httpprojectstore.h
index 1d71329b1..a1f649ed6 100644
--- a/src/zenserver/storage/projectstore/httpprojectstore.h
+++ b/src/zenserver/storage/projectstore/httpprojectstore.h
@@ -51,8 +51,9 @@ public:
virtual const char* BaseUri() const override;
virtual void HandleRequest(HttpServerRequest& Request) override;
- virtual void HandleStatsRequest(HttpServerRequest& Request) override;
- virtual void HandleStatusRequest(HttpServerRequest& Request) override;
+ virtual CbObject CollectStats() override;
+ virtual void HandleStatsRequest(HttpServerRequest& Request) override;
+ virtual void HandleStatusRequest(HttpServerRequest& Request) override;
private:
struct ProjectStats
diff --git a/src/zenserver/storage/storageconfig.cpp b/src/zenserver/storage/storageconfig.cpp
index 99d0f89d7..ad1fb88ea 100644
--- a/src/zenserver/storage/storageconfig.cpp
+++ b/src/zenserver/storage/storageconfig.cpp
@@ -804,6 +804,7 @@ ZenStorageServerCmdLineOptions::AddCacheOptions(cxxopts::Options& options, ZenSt
cxxopts::value<uint64_t>(ServerOptions.StructuredCacheConfig.MemMaxAgeSeconds)->default_value("86400"),
"");
+ options.add_option("compute", "", "lie-cpus", "Lie to upstream about CPU capabilities", cxxopts::value<int>(ServerOptions.LieCpu), "");
options.add_option("cache",
"",
"cache-bucket-maxblocksize",
diff --git a/src/zenserver/storage/storageconfig.h b/src/zenserver/storage/storageconfig.h
index bc2dc78c9..d935ed8b3 100644
--- a/src/zenserver/storage/storageconfig.h
+++ b/src/zenserver/storage/storageconfig.h
@@ -1,4 +1,5 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#pragma once
#include "config/config.h"
@@ -156,6 +157,7 @@ struct ZenStorageServerConfig : public ZenServerConfig
ZenWorkspacesConfig WorksSpacesConfig;
std::filesystem::path PluginsConfigFile; // Path to plugins config file
bool ObjectStoreEnabled = false;
+ bool ComputeEnabled = true;
std::string ScrubOptions;
bool RestrictContentTypes = false;
};
diff --git a/src/zenserver/storage/workspaces/httpworkspaces.cpp b/src/zenserver/storage/workspaces/httpworkspaces.cpp
index dc4cc7e69..785dd62f0 100644
--- a/src/zenserver/storage/workspaces/httpworkspaces.cpp
+++ b/src/zenserver/storage/workspaces/httpworkspaces.cpp
@@ -110,8 +110,8 @@ HttpWorkspacesService::HandleRequest(HttpServerRequest& Request)
}
}
-void
-HttpWorkspacesService::HandleStatsRequest(HttpServerRequest& HttpReq)
+CbObject
+HttpWorkspacesService::CollectStats()
{
ZEN_TRACE_CPU("WorkspacesService::Stats");
CbObjectWriter Cbo;
@@ -150,7 +150,13 @@ HttpWorkspacesService::HandleStatsRequest(HttpServerRequest& HttpReq)
}
Cbo.EndObject();
- return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ return Cbo.Save();
+}
+
+void
+HttpWorkspacesService::HandleStatsRequest(HttpServerRequest& HttpReq)
+{
+ HttpReq.WriteResponse(HttpResponseCode::OK, CollectStats());
}
void
diff --git a/src/zenserver/storage/workspaces/httpworkspaces.h b/src/zenserver/storage/workspaces/httpworkspaces.h
index 888a34b4d..7c5ddeff1 100644
--- a/src/zenserver/storage/workspaces/httpworkspaces.h
+++ b/src/zenserver/storage/workspaces/httpworkspaces.h
@@ -29,8 +29,9 @@ public:
virtual const char* BaseUri() const override;
virtual void HandleRequest(HttpServerRequest& Request) override;
- virtual void HandleStatsRequest(HttpServerRequest& Request) override;
- virtual void HandleStatusRequest(HttpServerRequest& Request) override;
+ virtual CbObject CollectStats() override;
+ virtual void HandleStatsRequest(HttpServerRequest& Request) override;
+ virtual void HandleStatusRequest(HttpServerRequest& Request) override;
private:
struct WorkspacesStats
diff --git a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp
index ea05bd155..f43bb9987 100644
--- a/src/zenserver/storage/zenstorageserver.cpp
+++ b/src/zenserver/storage/zenstorageserver.cpp
@@ -33,6 +33,7 @@
#include <zenutil/service.h>
#include <zenutil/workerpools.h>
#include <zenutil/zenserverprocess.h>
+#include "../sessions/sessions.h"
#if ZEN_PLATFORM_WINDOWS
# include <zencore/windows.h>
@@ -133,7 +134,6 @@ void
ZenStorageServer::RegisterServices()
{
m_Http->RegisterService(*m_AuthService);
- m_Http->RegisterService(m_StatsService);
m_Http->RegisterService(m_TestService); // NOTE: this is intentionally not limited to test mode as it's useful for diagnostics
#if ZEN_WITH_TESTS
@@ -160,6 +160,11 @@ ZenStorageServer::RegisterServices()
m_Http->RegisterService(*m_HttpWorkspacesService);
}
+ if (m_HttpSessionsService)
+ {
+ m_Http->RegisterService(*m_HttpSessionsService);
+ }
+
m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatusService);
if (m_FrontendService)
@@ -182,6 +187,18 @@ ZenStorageServer::RegisterServices()
#endif // ZEN_WITH_VFS
m_Http->RegisterService(*m_AdminService);
+
+ if (m_ApiService)
+ {
+ m_Http->RegisterService(*m_ApiService);
+ }
+
+#if ZEN_WITH_COMPUTE_SERVICES
+ if (m_HttpComputeService)
+ {
+ m_Http->RegisterService(*m_HttpComputeService);
+ }
+#endif
}
void
@@ -227,6 +244,11 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions
*m_Workspaces));
}
+ {
+ m_SessionsService = std::make_unique<SessionsService>();
+ m_HttpSessionsService = std::make_unique<HttpSessionsService>(m_StatusService, m_StatsService, *m_SessionsService);
+ }
+
if (ServerOptions.BuildStoreConfig.Enabled)
{
CidStoreConfiguration BuildCidConfig;
@@ -273,6 +295,16 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions
m_BuildStoreService = std::make_unique<HttpBuildStoreService>(m_StatusService, m_StatsService, *m_BuildStore);
}
+#if ZEN_WITH_COMPUTE_SERVICES
+ if (ServerOptions.ComputeEnabled)
+ {
+ ZEN_OTEL_SPAN("InitializeComputeService");
+
+ m_HttpComputeService =
+ std::make_unique<compute::HttpComputeService>(*m_CidStore, m_StatsService, ServerOptions.DataDir / "functions");
+ }
+#endif
+
#if ZEN_WITH_VFS
m_VfsServiceImpl = std::make_unique<VfsServiceImpl>();
m_VfsServiceImpl->AddService(Ref<ProjectStore>(m_ProjectStore));
@@ -305,13 +337,15 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions
.AttachmentPassCount = ServerOptions.GcConfig.AttachmentPassCount};
m_GcScheduler.Initialize(GcConfig);
+ m_ApiService = std::make_unique<HttpApiService>(*m_Http);
+
// Create and register admin interface last to make sure all is properly initialized
m_AdminService = std::make_unique<HttpAdminService>(
m_GcScheduler,
*m_JobQueue,
m_CacheStore.Get(),
[this]() { Flush(); },
- HttpAdminService::LogPaths{.AbsLogPath = ServerOptions.AbsLogFile,
+ HttpAdminService::LogPaths{.AbsLogPath = ServerOptions.LoggingConfig.AbsLogFile,
.HttpLogPath = ServerOptions.DataDir / "logs" / "http.log",
.CacheLogPath = ServerOptions.DataDir / "logs" / "z$.log"},
ServerOptions);
@@ -689,6 +723,15 @@ ZenStorageServer::Run()
ZEN_INFO(ZEN_APP_NAME " now running (pid: {})", GetCurrentProcessId());
+ if (m_FrontendService)
+ {
+ ZEN_INFO("frontend link: {}", m_Http->GetServiceUri(m_FrontendService.get()));
+ }
+ else
+ {
+ ZEN_INFO("frontend service disabled");
+ }
+
#if ZEN_PLATFORM_WINDOWS
if (zen::windows::IsRunningOnWine())
{
@@ -796,6 +839,8 @@ ZenStorageServer::Cleanup()
m_IoRunner.join();
}
+ ShutdownServices();
+
if (m_Http)
{
m_Http->Close();
@@ -811,6 +856,10 @@ ZenStorageServer::Cleanup()
Flush();
+#if ZEN_WITH_COMPUTE_SERVICES
+ m_HttpComputeService.reset();
+#endif
+
m_AdminService.reset();
m_VfsService.reset();
m_VfsServiceImpl.reset();
@@ -826,6 +875,8 @@ ZenStorageServer::Cleanup()
m_UpstreamCache.reset();
m_CacheStore = {};
+ m_HttpSessionsService.reset();
+ m_SessionsService.reset();
m_HttpWorkspacesService.reset();
m_Workspaces.reset();
m_HttpProjectService.reset();
diff --git a/src/zenserver/storage/zenstorageserver.h b/src/zenserver/storage/zenstorageserver.h
index 5ccb587d6..d625f869c 100644
--- a/src/zenserver/storage/zenstorageserver.h
+++ b/src/zenserver/storage/zenstorageserver.h
@@ -6,11 +6,13 @@
#include <zenhttp/auth/authmgr.h>
#include <zenhttp/auth/authservice.h>
+#include <zenhttp/httpapiservice.h>
#include <zenhttp/httptest.h>
#include <zenstore/cache/structuredcachestore.h>
#include <zenstore/gc.h>
#include <zenstore/projectstore.h>
+#include "../sessions/httpsessions.h"
#include "admin/admin.h"
#include "buildstore/httpbuildstore.h"
#include "cache/httpstructuredcache.h"
@@ -23,6 +25,10 @@
#include "vfs/vfsservice.h"
#include "workspaces/httpworkspaces.h"
+#if ZEN_WITH_COMPUTE_SERVICES
+# include <zencompute/httpcomputeservice.h>
+#endif
+
namespace zen {
class ZenStorageServer : public ZenServerBase
@@ -34,11 +40,6 @@ public:
ZenStorageServer();
~ZenStorageServer();
- void SetDedicatedMode(bool State) { m_IsDedicatedMode = State; }
- void SetTestMode(bool State) { m_TestMode = State; }
- void SetDataRoot(std::filesystem::path Root) { m_DataRoot = Root; }
- void SetContentRoot(std::filesystem::path Root) { m_ContentRoot = Root; }
-
int Initialize(const ZenStorageServerConfig& ServerOptions, ZenServerState::ZenServerEntry* ServerEntry);
void Run();
void Cleanup();
@@ -48,14 +49,9 @@ private:
void InitializeStructuredCache(const ZenStorageServerConfig& ServerOptions);
void Flush();
- bool m_IsDedicatedMode = false;
- bool m_TestMode = false;
- bool m_DebugOptionForcedCrash = false;
- std::string m_StartupScrubOptions;
- CbObject m_RootManifest;
- std::filesystem::path m_DataRoot;
- std::filesystem::path m_ContentRoot;
- asio::steady_timer m_StateMarkerTimer{m_IoContext};
+ std::string m_StartupScrubOptions;
+ CbObject m_RootManifest;
+ asio::steady_timer m_StateMarkerTimer{m_IoContext};
void EnqueueStateMarkerTimer();
void CheckStateMarker();
@@ -67,7 +63,6 @@ private:
void InitializeServices(const ZenStorageServerConfig& ServerOptions);
void RegisterServices();
- HttpStatsService m_StatsService;
std::unique_ptr<JobQueue> m_JobQueue;
GcManager m_GcManager;
GcScheduler m_GcScheduler{m_GcManager};
@@ -87,6 +82,8 @@ private:
std::unique_ptr<HttpProjectService> m_HttpProjectService;
std::unique_ptr<Workspaces> m_Workspaces;
std::unique_ptr<HttpWorkspacesService> m_HttpWorkspacesService;
+ std::unique_ptr<SessionsService> m_SessionsService;
+ std::unique_ptr<HttpSessionsService> m_HttpSessionsService;
std::unique_ptr<UpstreamCache> m_UpstreamCache;
std::unique_ptr<HttpUpstreamService> m_UpstreamService;
std::unique_ptr<HttpStructuredCacheService> m_StructuredCacheService;
@@ -95,6 +92,11 @@ private:
std::unique_ptr<HttpBuildStoreService> m_BuildStoreService;
std::unique_ptr<VfsService> m_VfsService;
std::unique_ptr<HttpAdminService> m_AdminService;
+ std::unique_ptr<HttpApiService> m_ApiService;
+
+#if ZEN_WITH_COMPUTE_SERVICES
+ std::unique_ptr<compute::HttpComputeService> m_HttpComputeService;
+#endif
};
struct ZenStorageServerConfigurator;
diff --git a/src/zenserver/trace/tracerecorder.cpp b/src/zenserver/trace/tracerecorder.cpp
new file mode 100644
index 000000000..5dec20e18
--- /dev/null
+++ b/src/zenserver/trace/tracerecorder.cpp
@@ -0,0 +1,565 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "tracerecorder.h"
+
+#include <zencore/basicfile.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/uid.h>
+
+#include <asio.hpp>
+
+#include <atomic>
+#include <cstring>
+#include <memory>
+#include <mutex>
+#include <thread>
+
+namespace zen {
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct TraceSession : public std::enable_shared_from_this<TraceSession>
+{
+ TraceSession(asio::ip::tcp::socket&& Socket, const std::filesystem::path& OutputDir)
+ : m_Socket(std::move(Socket))
+ , m_OutputDir(OutputDir)
+ , m_SessionId(Oid::NewOid())
+ {
+ try
+ {
+ m_RemoteAddress = m_Socket.remote_endpoint().address().to_string();
+ }
+ catch (...)
+ {
+ m_RemoteAddress = "unknown";
+ }
+
+ ZEN_INFO("Trace session {} started from {}", m_SessionId, m_RemoteAddress);
+ }
+
+ ~TraceSession()
+ {
+ if (m_TraceFile.IsOpen())
+ {
+ m_TraceFile.Close();
+ }
+
+ ZEN_INFO("Trace session {} ended, {} bytes recorded to '{}'", m_SessionId, m_TotalBytesRecorded, m_TraceFilePath);
+ }
+
+ void Start() { ReadPreambleHeader(); }
+
+ bool IsActive() const { return m_Socket.is_open(); }
+
+ TraceSessionInfo GetInfo() const
+ {
+ TraceSessionInfo Info;
+ Info.SessionGuid = m_SessionGuid;
+ Info.TraceGuid = m_TraceGuid;
+ Info.ControlPort = m_ControlPort;
+ Info.TransportVersion = m_TransportVersion;
+ Info.ProtocolVersion = m_ProtocolVersion;
+ Info.RemoteAddress = m_RemoteAddress;
+ Info.BytesRecorded = m_TotalBytesRecorded;
+ Info.TraceFilePath = m_TraceFilePath;
+ return Info;
+ }
+
+private:
+ // Preamble format:
+ // [magic: 4 bytes][metadata_size: 2 bytes][metadata fields: variable][version: 2 bytes]
+ //
+ // Magic bytes: [0]=version_char ('2'-'9'), [1]='C', [2]='R', [3]='T'
+ //
+ // Metadata fields (repeated):
+ // [size: 1 byte][id: 1 byte][data: <size> bytes]
+ // Field 0: ControlPort (uint16)
+ // Field 1: SessionGuid (16 bytes)
+ // Field 2: TraceGuid (16 bytes)
+ //
+ // Version: [transport: 1 byte][protocol: 1 byte]
+
+ static constexpr size_t kMagicSize = 4;
+ static constexpr size_t kMetadataSizeFieldSize = 2;
+ static constexpr size_t kPreambleHeaderSize = kMagicSize + kMetadataSizeFieldSize;
+ static constexpr size_t kVersionSize = 2;
+ static constexpr size_t kPreambleBufferSize = 256;
+ static constexpr size_t kReadBufferSize = 64 * 1024;
+
+ void ReadPreambleHeader()
+ {
+ auto Self = shared_from_this();
+
+ // Read the first 6 bytes: 4 magic + 2 metadata size
+ asio::async_read(m_Socket,
+ asio::buffer(m_PreambleBuffer, kPreambleHeaderSize),
+ [this, Self](const asio::error_code& Ec, std::size_t /*BytesRead*/) {
+ if (Ec)
+ {
+ HandleReadError("preamble header", Ec);
+ return;
+ }
+
+ if (!ValidateMagic())
+ {
+ ZEN_WARN("Trace session {}: invalid trace magic header", m_SessionId);
+ CloseSocket();
+ return;
+ }
+
+ ReadPreambleMetadata();
+ });
+ }
+
+ bool ValidateMagic()
+ {
+ const uint8_t* Cursor = m_PreambleBuffer;
+
+ // Validate magic: bytes are version, 'C', 'R', 'T'
+ if (Cursor[3] != 'T' || Cursor[2] != 'R' || Cursor[1] != 'C')
+ {
+ return false;
+ }
+
+ if (Cursor[0] < '2' || Cursor[0] > '9')
+ {
+ return false;
+ }
+
+ // Extract the metadata fields size (does not include the trailing version bytes)
+ std::memcpy(&m_MetadataFieldsSize, Cursor + kMagicSize, sizeof(m_MetadataFieldsSize));
+
+ if (m_MetadataFieldsSize + kVersionSize > kPreambleBufferSize - kPreambleHeaderSize)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ void ReadPreambleMetadata()
+ {
+ auto Self = shared_from_this();
+ size_t ReadSize = m_MetadataFieldsSize + kVersionSize;
+
+ // Read metadata fields + 2 version bytes
+ asio::async_read(m_Socket,
+ asio::buffer(m_PreambleBuffer + kPreambleHeaderSize, ReadSize),
+ [this, Self](const asio::error_code& Ec, std::size_t /*BytesRead*/) {
+ if (Ec)
+ {
+ HandleReadError("preamble metadata", Ec);
+ return;
+ }
+
+ if (!ParseMetadata())
+ {
+ ZEN_WARN("Trace session {}: malformed trace metadata", m_SessionId);
+ CloseSocket();
+ return;
+ }
+
+ if (!CreateTraceFile())
+ {
+ CloseSocket();
+ return;
+ }
+
+ // Write the full preamble to the trace file so it remains a valid .utrace
+ size_t PreambleSize = kPreambleHeaderSize + m_MetadataFieldsSize + kVersionSize;
+ std::error_code WriteEc;
+ m_TraceFile.Write(m_PreambleBuffer, PreambleSize, 0, WriteEc);
+
+ if (WriteEc)
+ {
+ ZEN_ERROR("Trace session {}: failed to write preamble: {}", m_SessionId, WriteEc.message());
+ CloseSocket();
+ return;
+ }
+
+ m_TotalBytesRecorded = PreambleSize;
+
+ ZEN_INFO("Trace session {}: metadata - TransportV{} ProtocolV{} ControlPort:{} SessionGuid:{} TraceGuid:{}",
+ m_SessionId,
+ m_TransportVersion,
+ m_ProtocolVersion,
+ m_ControlPort,
+ m_SessionGuid,
+ m_TraceGuid);
+
+ // Begin streaming trace data to disk
+ ReadMore();
+ });
+ }
+
    // Parses the metadata fields that follow the preamble header in
    // m_PreambleBuffer, then the two version bytes that trail them.
    //
    // Wire format per field: [uint8 FieldSize][uint8 FieldId][FieldSize payload bytes].
    // Known field ids: 0 = ControlPort (uint16), 1 = SessionGuid, 2 = TraceGuid.
    // Unknown ids are skipped over, keeping the parser forward-compatible.
    //
    // Returns false when the fields do not exactly fill m_MetadataFieldsSize
    // bytes (truncated or overlong metadata); on success the m_ControlPort,
    // m_SessionGuid, m_TraceGuid, m_TransportVersion and m_ProtocolVersion
    // members are populated. Assumes the preamble read has already filled
    // kPreambleHeaderSize + m_MetadataFieldsSize + kVersionSize bytes of
    // m_PreambleBuffer.
    bool ParseMetadata()
    {
        const uint8_t* Cursor = m_PreambleBuffer + kPreambleHeaderSize;
        int32_t Remaining = static_cast<int32_t>(m_MetadataFieldsSize);

        // Each field needs at least its two header bytes (size + id)
        while (Remaining >= 2)
        {
            uint8_t FieldSize = Cursor[0];
            uint8_t FieldId = Cursor[1];
            Cursor += 2;
            Remaining -= 2;

            // Declared payload runs past the metadata region => malformed
            if (Remaining < FieldSize)
            {
                return false;
            }

            switch (FieldId)
            {
                case 0: // ControlPort
                    if (FieldSize >= sizeof(uint16_t))
                    {
                        std::memcpy(&m_ControlPort, Cursor, sizeof(uint16_t));
                    }
                    break;
                case 1: // SessionGuid
                    if (FieldSize >= sizeof(Guid))
                    {
                        std::memcpy(&m_SessionGuid, Cursor, sizeof(Guid));
                    }
                    break;
                case 2: // TraceGuid
                    if (FieldSize >= sizeof(Guid))
                    {
                        std::memcpy(&m_TraceGuid, Cursor, sizeof(Guid));
                    }
                    break;
            }

            Cursor += FieldSize;
            Remaining -= FieldSize;
        }

        // Metadata should be fully consumed
        if (Remaining != 0)
        {
            return false;
        }

        // Version bytes follow immediately after the metadata fields
        const uint8_t* VersionPtr = m_PreambleBuffer + kPreambleHeaderSize + m_MetadataFieldsSize;
        m_TransportVersion = VersionPtr[0];
        m_ProtocolVersion = VersionPtr[1];

        return true;
    }
+
+ bool CreateTraceFile()
+ {
+ m_TraceFilePath = m_OutputDir / fmt::format("{}.utrace", m_SessionId);
+
+ try
+ {
+ m_TraceFile.Open(m_TraceFilePath, BasicFile::Mode::kTruncate);
+ ZEN_INFO("Trace session {} writing to '{}'", m_SessionId, m_TraceFilePath);
+ return true;
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_ERROR("Trace session {}: failed to create trace file '{}': {}", m_SessionId, m_TraceFilePath, Ex.what());
+ return false;
+ }
+ }
+
    // Pulls trace data off the socket and appends it to the trace file.
    // Re-arms itself after every successful read, so it keeps streaming until
    // EOF, cancellation, or a read/write error ends the chain. Captures
    // shared_from_this() so the session object outlives the pending read.
    void ReadMore()
    {
        auto Self = shared_from_this();

        m_Socket.async_read_some(asio::buffer(m_ReadBuffer, kReadBufferSize),
                                 [this, Self](const asio::error_code& Ec, std::size_t BytesRead) {
                                     if (!Ec)
                                     {
                                         if (BytesRead > 0 && m_TraceFile.IsOpen())
                                         {
                                             std::error_code WriteEc;
                                             // Append at the running total, which doubles as the file cursor
                                             const uint64_t FileOffset = m_TotalBytesRecorded;
                                             m_TraceFile.Write(m_ReadBuffer, BytesRead, FileOffset, WriteEc);

                                             if (WriteEc)
                                             {
                                                 ZEN_ERROR("Trace session {}: write error: {}", m_SessionId, WriteEc.message());
                                                 CloseSocket();
                                                 return;
                                             }

                                             m_TotalBytesRecorded += BytesRead;
                                         }

                                         ReadMore();
                                     }
                                     else if (Ec == asio::error::eof)
                                     {
                                         // Normal termination: producer finished and closed its end
                                         ZEN_DEBUG("Trace session {} connection closed by peer", m_SessionId);
                                         CloseSocket();
                                     }
                                     else if (Ec == asio::error::operation_aborted)
                                     {
                                         // Cancelled (e.g. socket closed locally); no further cleanup here
                                         ZEN_DEBUG("Trace session {} operation aborted", m_SessionId);
                                     }
                                     else
                                     {
                                         ZEN_WARN("Trace session {} read error: {}", m_SessionId, Ec.message());
                                         CloseSocket();
                                     }
                                 });
    }
+
+ void HandleReadError(const char* Phase, const asio::error_code& Ec)
+ {
+ if (Ec == asio::error::eof)
+ {
+ ZEN_DEBUG("Trace session {}: connection closed during {}", m_SessionId, Phase);
+ }
+ else if (Ec == asio::error::operation_aborted)
+ {
+ ZEN_DEBUG("Trace session {}: operation aborted during {}", m_SessionId, Phase);
+ }
+ else
+ {
+ ZEN_WARN("Trace session {}: error during {}: {}", m_SessionId, Phase, Ec.message());
+ }
+
+ CloseSocket();
+ }
+
+ void CloseSocket()
+ {
+ std::error_code Ec;
+ m_Socket.close(Ec);
+
+ if (m_TraceFile.IsOpen())
+ {
+ m_TraceFile.Close();
+ }
+ }
+
    asio::ip::tcp::socket m_Socket;         // Connection to the trace producer
    std::filesystem::path m_OutputDir;      // Directory where .utrace files are written
    std::filesystem::path m_TraceFilePath;  // Full path of this session's trace file
    BasicFile m_TraceFile;                  // Destination file for the recorded stream
    Oid m_SessionId;                        // Names the output file and tags all log lines
    std::string m_RemoteAddress;            // Peer address (presumably set at accept time — set outside this view)

    // Preamble parsing
    uint8_t m_PreambleBuffer[kPreambleBufferSize] = {};  // Raw preamble bytes; ParseMetadata() reads from here
    uint16_t m_MetadataFieldsSize = 0;                   // Byte size of the metadata-fields region within the preamble

    // Extracted metadata
    Guid m_SessionGuid{};            // Field id 1 in the preamble metadata
    Guid m_TraceGuid{};              // Field id 2 in the preamble metadata
    uint16_t m_ControlPort = 0;      // Field id 0 in the preamble metadata
    uint8_t m_TransportVersion = 0;  // First version byte after the metadata fields
    uint8_t m_ProtocolVersion = 0;   // Second version byte after the metadata fields

    // Streaming
    uint8_t m_ReadBuffer[kReadBufferSize];  // Scratch buffer; always written before read, so left uninitialized
    uint64_t m_TotalBytesRecorded = 0;      // Bytes written so far; also the next file write offset
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct TraceRecorder::Impl
+{
+ Impl() : m_IoContext(), m_Acceptor(m_IoContext) {}
+
+ ~Impl() { Shutdown(); }
+
+ void Initialize(uint16_t InPort, const std::filesystem::path& OutputDir)
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+
+ if (m_IsRunning)
+ {
+ ZEN_WARN("TraceRecorder already initialized");
+ return;
+ }
+
+ m_OutputDir = OutputDir;
+
+ try
+ {
+ // Create output directory if it doesn't exist
+ CreateDirectories(m_OutputDir);
+
+ // Configure acceptor
+ m_Acceptor.open(asio::ip::tcp::v4());
+ m_Acceptor.set_option(asio::socket_base::reuse_address(true));
+ m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::tcp::v4(), InPort));
+ m_Acceptor.listen();
+
+ m_Port = m_Acceptor.local_endpoint().port();
+
+ ZEN_INFO("TraceRecorder listening on port {}, output directory: '{}'", m_Port, m_OutputDir);
+
+ m_IsRunning = true;
+
+ // Start accepting connections
+ StartAccept();
+
+ // Start IO thread
+ m_IoThread = std::thread([this]() {
+ try
+ {
+ m_IoContext.run();
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_ERROR("TraceRecorder IO thread exception: {}", Ex.what());
+ }
+ });
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_ERROR("Failed to initialize TraceRecorder: {}", Ex.what());
+ m_IsRunning = false;
+ throw;
+ }
+ }
+
+ void Shutdown()
+ {
+ std::lock_guard<std::mutex> Lock(m_Mutex);
+
+ if (!m_IsRunning)
+ {
+ return;
+ }
+
+ ZEN_INFO("TraceRecorder shutting down");
+
+ m_IsRunning = false;
+
+ std::error_code Ec;
+ m_Acceptor.close(Ec);
+
+ m_IoContext.stop();
+
+ if (m_IoThread.joinable())
+ {
+ m_IoThread.join();
+ }
+
+ {
+ std::lock_guard<std::mutex> SessionLock(m_SessionsMutex);
+ m_Sessions.clear();
+ }
+
+ ZEN_INFO("TraceRecorder shutdown complete");
+ }
+
+ bool IsRunning() const { return m_IsRunning; }
+
+ uint16_t GetPort() const { return m_Port; }
+
+ std::vector<TraceSessionInfo> GetActiveSessions() const
+ {
+ std::lock_guard<std::mutex> Lock(m_SessionsMutex);
+
+ std::vector<TraceSessionInfo> Result;
+ for (const auto& WeakSession : m_Sessions)
+ {
+ if (auto Session = WeakSession.lock())
+ {
+ if (Session->IsActive())
+ {
+ Result.push_back(Session->GetInfo());
+ }
+ }
+ }
+ return Result;
+ }
+
+private:
+ void StartAccept()
+ {
+ auto Socket = std::make_shared<asio::ip::tcp::socket>(m_IoContext);
+
+ m_Acceptor.async_accept(*Socket, [this, Socket](const asio::error_code& Ec) {
+ if (!Ec)
+ {
+ auto Session = std::make_shared<TraceSession>(std::move(*Socket), m_OutputDir);
+
+ {
+ std::lock_guard<std::mutex> Lock(m_SessionsMutex);
+
+ // Prune expired sessions while adding the new one
+ std::erase_if(m_Sessions, [](const std::weak_ptr<TraceSession>& Wp) { return Wp.expired(); });
+ m_Sessions.push_back(Session);
+ }
+
+ Session->Start();
+ }
+ else if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("Accept error: {}", Ec.message());
+ }
+
+ // Continue accepting if still running
+ if (m_IsRunning)
+ {
+ StartAccept();
+ }
+ });
+ }
+
+ asio::io_context m_IoContext;
+ asio::ip::tcp::acceptor m_Acceptor;
+ std::thread m_IoThread;
+ std::filesystem::path m_OutputDir;
+ std::mutex m_Mutex;
+ std::atomic<bool> m_IsRunning{false};
+ uint16_t m_Port = 0;
+
+ mutable std::mutex m_SessionsMutex;
+ std::vector<std::weak_ptr<TraceSession>> m_Sessions;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
// Constructs an idle recorder; call Initialize() to start listening.
TraceRecorder::TraceRecorder() : m_Impl(std::make_unique<Impl>())
{
}
+
// Tears the recorder down (accept loop, IO thread) before the pimpl dies.
TraceRecorder::~TraceRecorder()
{
    Shutdown();
}
+
// Starts listening on InPort and recording into OutputDir; forwards to
// Impl::Initialize, which throws on failure.
void
TraceRecorder::Initialize(uint16_t InPort, const std::filesystem::path& OutputDir)
{
    m_Impl->Initialize(InPort, OutputDir);
}
+
// Stops the recorder; forwards to Impl::Shutdown (idempotent).
void
TraceRecorder::Shutdown()
{
    m_Impl->Shutdown();
}
+
// True between a successful Initialize() and the matching Shutdown().
bool
TraceRecorder::IsRunning() const
{
    return m_Impl->IsRunning();
}
+
// Actual bound listen port (useful when Initialize() was given port 0).
uint16_t
TraceRecorder::GetPort() const
{
    return m_Impl->GetPort();
}
+
// Snapshot of the currently active recording sessions.
std::vector<TraceSessionInfo>
TraceRecorder::GetActiveSessions() const
{
    return m_Impl->GetActiveSessions();
}
+
+} // namespace zen
diff --git a/src/zenserver/trace/tracerecorder.h b/src/zenserver/trace/tracerecorder.h
new file mode 100644
index 000000000..48857aec8
--- /dev/null
+++ b/src/zenserver/trace/tracerecorder.h
@@ -0,0 +1,46 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/guid.h>
+#include <zencore/zencore.h>
+
+#include <filesystem>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace zen {
+
// Point-in-time description of one trace recording session, as returned by
// TraceRecorder::GetActiveSessions().
struct TraceSessionInfo
{
    Guid SessionGuid{};                    // Session GUID from the trace preamble metadata
    Guid TraceGuid{};                      // Trace GUID from the trace preamble metadata
    uint16_t ControlPort = 0;              // Control port advertised in the preamble metadata
    uint8_t TransportVersion = 0;          // Transport version byte from the preamble
    uint8_t ProtocolVersion = 0;           // Protocol version byte from the preamble
    std::string RemoteAddress;             // Peer address of the incoming connection
    uint64_t BytesRecorded = 0;            // Bytes written to the trace file so far
    std::filesystem::path TraceFilePath;   // Destination .utrace file on disk
};
+
// Accepts incoming TCP trace connections and records each session to a
// .utrace file on disk. All state lives behind a pimpl; the recorder runs
// its own IO thread once initialized.
class TraceRecorder
{
public:
    TraceRecorder();
    ~TraceRecorder(); // Shuts the recorder down if still running

    // Starts listening on InPort (0 = pick an ephemeral port) and writing
    // session files into OutputDir. Throws on failure.
    void Initialize(uint16_t InPort, const std::filesystem::path& OutputDir);
    // Stops accepting connections and joins the IO thread. Idempotent.
    void Shutdown();

    bool IsRunning() const;
    // Actual bound port; meaningful after a successful Initialize().
    uint16_t GetPort() const;

    // Snapshot of the currently active recording sessions.
    std::vector<TraceSessionInfo> GetActiveSessions() const;

private:
    struct Impl;
    std::unique_ptr<Impl> m_Impl;
};
+
+} // namespace zen \ No newline at end of file
diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua
index 6ee80dc62..f2ed17f05 100644
--- a/src/zenserver/xmake.lua
+++ b/src/zenserver/xmake.lua
@@ -2,7 +2,11 @@
target("zenserver")
set_kind("binary")
+ if enable_unity then
+ add_rules("c++.unity_build", {batchsize = 4})
+ end
add_deps("zencore",
+ "zencompute",
"zenhttp",
"zennet",
"zenremotestore",
@@ -15,6 +19,12 @@ target("zenserver")
add_files("**.cpp")
add_files("frontend/*.zip")
add_files("zenserver.cpp", {unity_ignored = true })
+
+ if is_plat("linux") and not (get_config("toolchain") or ""):find("clang") then
+ -- GCC false positives in deeply inlined code (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100137)
+ add_files("storage/projectstore/httpprojectstore.cpp", {force = {cxxflags = "-Wno-stringop-overflow"} })
+ add_files("storage/storageconfig.cpp", {force = {cxxflags = "-Wno-array-bounds"} })
+ end
add_includedirs(".")
set_symbols("debug")
@@ -23,6 +33,8 @@ target("zenserver")
add_packages("json11")
add_packages("lua")
add_packages("consul")
+ add_packages("oidctoken")
+ add_packages("nomad")
if has_config("zenmimalloc") then
add_packages("mimalloc")
@@ -32,6 +44,14 @@ target("zenserver")
add_packages("sentry-native")
end
+ if has_config("zenhorde") then
+ add_deps("zenhorde")
+ end
+
+ if has_config("zennomad") then
+ add_deps("zennomad")
+ end
+
if is_mode("release") then
set_optimize("fastest")
end
@@ -141,4 +161,24 @@ target("zenserver")
end
copy_if_newer(path.join(installdir, "bin", consul_bin), path.join(target:targetdir(), consul_bin), consul_bin)
end
+
+ local oidctoken_pkg = target:pkg("oidctoken")
+ if oidctoken_pkg then
+ local installdir = oidctoken_pkg:installdir()
+ local oidctoken_bin = "OidcToken"
+ if is_plat("windows") then
+ oidctoken_bin = "OidcToken.exe"
+ end
+ copy_if_newer(path.join(installdir, "bin", oidctoken_bin), path.join(target:targetdir(), oidctoken_bin), oidctoken_bin)
+ end
+
+ local nomad_pkg = target:pkg("nomad")
+ if nomad_pkg then
+ local installdir = nomad_pkg:installdir()
+ local nomad_bin = "nomad"
+ if is_plat("windows") then
+ nomad_bin = "nomad.exe"
+ end
+ copy_if_newer(path.join(installdir, "bin", nomad_bin), path.join(target:targetdir(), nomad_bin), nomad_bin)
+ end
end)
diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp
index 2bafeeaa1..bb6b02d21 100644
--- a/src/zenserver/zenserver.cpp
+++ b/src/zenserver/zenserver.cpp
@@ -18,11 +18,13 @@
#include <zencore/sentryintegration.h>
#include <zencore/session.h>
#include <zencore/string.h>
+#include <zencore/system.h>
#include <zencore/thread.h>
#include <zencore/timer.h>
#include <zencore/trace.h>
#include <zencore/workthreadpool.h>
#include <zenhttp/httpserver.h>
+#include <zenhttp/security/passwordsecurityfilter.h>
#include <zentelemetry/otlptrace.h>
#include <zenutil/service.h>
#include <zenutil/workerpools.h>
@@ -44,6 +46,20 @@ ZEN_THIRD_PARTY_INCLUDES_END
//////////////////////////////////////////////////////////////////////////
+#ifndef ZEN_WITH_COMPUTE_SERVICES
+# define ZEN_WITH_COMPUTE_SERVICES 0
+#endif
+
+#ifndef ZEN_WITH_HORDE
+# define ZEN_WITH_HORDE 0
+#endif
+
+#ifndef ZEN_WITH_NOMAD
+# define ZEN_WITH_NOMAD 0
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
#include "config/config.h"
#include "diag/logging.h"
@@ -142,8 +158,18 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState::
ZEN_INFO("Effective concurrency: {} (hw: {})", GetHardwareConcurrency(), std::thread::hardware_concurrency());
+ InitializeSecuritySettings(ServerOptions);
+
+ if (ServerOptions.LieCpu)
+ {
+ SetCpuCountForReporting(ServerOptions.LieCpu);
+
+ ZEN_INFO("Reporting concurrency: {}", ServerOptions.LieCpu);
+ }
+
m_StatusService.RegisterHandler("status", *this);
m_Http->RegisterService(m_StatusService);
+ m_Http->RegisterService(m_StatsService);
m_StatsReporter.Initialize(ServerOptions.StatsConfig);
if (ServerOptions.StatsConfig.Enabled)
@@ -151,10 +177,37 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState::
EnqueueStatsReportingTimer();
}
- m_HealthService.SetHealthInfo({.DataRoot = ServerOptions.DataDir,
- .AbsLogPath = ServerOptions.AbsLogFile,
- .HttpServerClass = std::string(ServerOptions.HttpConfig.ServerClass),
- .BuildVersion = std::string(ZEN_CFG_VERSION_BUILD_STRING_FULL)});
+ // clang-format off
+ HealthServiceInfo HealthInfo {
+ .DataRoot = ServerOptions.DataDir,
+ .AbsLogPath = ServerOptions.LoggingConfig.AbsLogFile,
+ .HttpServerClass = std::string(ServerOptions.HttpConfig.ServerClass),
+ .BuildVersion = std::string(ZEN_CFG_VERSION_BUILD_STRING_FULL),
+ .Port = EffectiveBasePort,
+ .Pid = GetCurrentProcessId(),
+ .IsDedicated = ServerOptions.IsDedicated,
+ .StartTimeMs = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::system_clock::now().time_since_epoch()).count(),
+ .BuildOptions = {
+ {"ZEN_ADDRESS_SANITIZER", ZEN_ADDRESS_SANITIZER != 0},
+ {"ZEN_USE_SENTRY", ZEN_USE_SENTRY != 0},
+ {"ZEN_WITH_TESTS", ZEN_WITH_TESTS != 0},
+ {"ZEN_USE_MIMALLOC", ZEN_USE_MIMALLOC != 0},
+ {"ZEN_USE_RPMALLOC", ZEN_USE_RPMALLOC != 0},
+ {"ZEN_WITH_HTTPSYS", ZEN_WITH_HTTPSYS != 0},
+ {"ZEN_WITH_MEMTRACK", ZEN_WITH_MEMTRACK != 0},
+ {"ZEN_WITH_TRACE", ZEN_WITH_TRACE != 0},
+ {"ZEN_WITH_COMPUTE_SERVICES", ZEN_WITH_COMPUTE_SERVICES != 0},
+ {"ZEN_WITH_HORDE", ZEN_WITH_HORDE != 0},
+ {"ZEN_WITH_NOMAD", ZEN_WITH_NOMAD != 0},
+ },
+ .RuntimeConfig = BuildSettingsList(ServerOptions),
+ };
+ // clang-format on
+
+ HealthInfo.RuntimeConfig.emplace(HealthInfo.RuntimeConfig.begin() + 2, "EffectivePort"sv, fmt::to_string(EffectiveBasePort));
+
+ m_HealthService.SetHealthInfo(std::move(HealthInfo));
LogSettingsSummary(ServerOptions);
@@ -164,12 +217,23 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState::
void
ZenServerBase::Finalize()
{
+ m_StatsService.RegisterHandler("http", *m_Http);
+
+ m_Http->SetDefaultRedirect("/dashboard/");
+
// Register health service last so if we return "OK" for health it means all services have been properly initialized
m_Http->RegisterService(m_HealthService);
}
void
+ZenServerBase::ShutdownServices()
+{
+ m_StatsService.UnregisterHandler("http", *m_Http);
+ m_StatsService.Shutdown();
+}
+
+void
ZenServerBase::GetBuildOptions(StringBuilderBase& OutOptions, char Separator) const
{
ZEN_MEMSCOPE(GetZenserverTag());
@@ -375,46 +439,65 @@ ZenServerBase::CheckSigInt()
void
ZenServerBase::HandleStatusRequest(HttpServerRequest& Request)
{
+ auto Metrics = m_MetricsTracker.Query();
+
CbObjectWriter Cbo;
Cbo << "ok" << true;
Cbo << "state" << ToString(m_CurrentState);
+ Cbo << "hostname" << GetMachineName();
+ Cbo << "cpuUsagePercent" << Metrics.CpuUsagePercent;
Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
}
-void
-ZenServerBase::LogSettingsSummary(const ZenServerConfig& ServerConfig)
+std::vector<std::pair<std::string_view, std::string>>
+ZenServerBase::BuildSettingsList(const ZenServerConfig& ServerConfig)
{
// clang-format off
- std::list<std::pair<std::string_view, std::string>> Settings = {
- {"DataDir"sv, ServerConfig.DataDir.string()},
- {"AbsLogFile"sv, ServerConfig.AbsLogFile.string()},
- {"SystemRootDir"sv, ServerConfig.SystemRootDir.string()},
- {"ContentDir"sv, ServerConfig.ContentDir.string()},
+ std::vector<std::pair<std::string_view, std::string>> Settings = {
+ {"SystemRootDir"sv, fmt::format("{}", ServerConfig.SystemRootDir)},
+ {"ContentDir"sv, fmt::format("{}", ServerConfig.ContentDir)},
{"BasePort"sv, fmt::to_string(ServerConfig.BasePort)},
+ {"CoreLimit"sv, fmt::to_string(ServerConfig.CoreLimit)},
{"IsDebug"sv, fmt::to_string(ServerConfig.IsDebug)},
{"IsCleanStart"sv, fmt::to_string(ServerConfig.IsCleanStart)},
{"IsPowerCycle"sv, fmt::to_string(ServerConfig.IsPowerCycle)},
{"IsTest"sv, fmt::to_string(ServerConfig.IsTest)},
{"Detach"sv, fmt::to_string(ServerConfig.Detach)},
- {"NoConsoleOutput"sv, fmt::to_string(ServerConfig.NoConsoleOutput)},
- {"QuietConsole"sv, fmt::to_string(ServerConfig.QuietConsole)},
- {"CoreLimit"sv, fmt::to_string(ServerConfig.CoreLimit)},
- {"IsDedicated"sv, fmt::to_string(ServerConfig.IsDedicated)},
- {"ShouldCrash"sv, fmt::to_string(ServerConfig.ShouldCrash)},
+ {"NoConsoleOutput"sv, fmt::to_string(ServerConfig.LoggingConfig.NoConsoleOutput)},
+ {"QuietConsole"sv, fmt::to_string(ServerConfig.LoggingConfig.QuietConsole)},
{"ChildId"sv, ServerConfig.ChildId},
- {"LogId"sv, ServerConfig.LogId},
+ {"LogId"sv, ServerConfig.LoggingConfig.LogId},
{"Sentry DSN"sv, ServerConfig.SentryConfig.Dsn.empty() ? "not set" : ServerConfig.SentryConfig.Dsn},
{"Sentry Environment"sv, ServerConfig.SentryConfig.Environment},
{"Statsd Enabled"sv, fmt::to_string(ServerConfig.StatsConfig.Enabled)},
+ {"SecurityConfigPath"sv, fmt::format("{}", ServerConfig.SecurityConfigPath)},
};
// clang-format on
if (ServerConfig.StatsConfig.Enabled)
{
- Settings.emplace_back("Statsd Host", ServerConfig.StatsConfig.StatsdHost);
- Settings.emplace_back("Statsd Port", fmt::to_string(ServerConfig.StatsConfig.StatsdPort));
+ Settings.emplace_back("Statsd Host"sv, ServerConfig.StatsConfig.StatsdHost);
+ Settings.emplace_back("Statsd Port"sv, fmt::to_string(ServerConfig.StatsConfig.StatsdPort));
}
+ return Settings;
+}
+
+void
+ZenServerBase::LogSettingsSummary(const ZenServerConfig& ServerConfig)
+{
+ auto Settings = BuildSettingsList(ServerConfig);
+
+ // Log-only entries not needed in RuntimeConfig
+ // clang-format off
+ Settings.insert(Settings.begin(), {
+ {"DataDir"sv, fmt::format("{}", ServerConfig.DataDir)},
+ {"AbsLogFile"sv, fmt::format("{}", ServerConfig.LoggingConfig.AbsLogFile)},
+ });
+ // clang-format on
+ Settings.emplace_back("IsDedicated"sv, fmt::to_string(ServerConfig.IsDedicated));
+ Settings.emplace_back("ShouldCrash"sv, fmt::to_string(ServerConfig.ShouldCrash));
+
size_t MaxWidth = 0;
for (const auto& Setting : Settings)
{
@@ -432,6 +515,44 @@ ZenServerBase::LogSettingsSummary(const ZenServerConfig& ServerConfig)
}
}
+void
+ZenServerBase::InitializeSecuritySettings(const ZenServerConfig& ServerOptions)
+{
+ ZEN_ASSERT(m_Http);
+
+ if (!ServerOptions.SecurityConfigPath.empty())
+ {
+ IoBuffer SecurityJson = ReadFile(ServerOptions.SecurityConfigPath).Flatten();
+ std::string_view Json(reinterpret_cast<const char*>(SecurityJson.GetData()), SecurityJson.GetSize());
+ std::string JsonError;
+ CbObject SecurityConfig = LoadCompactBinaryFromJson(Json, JsonError).AsObject();
+ if (!JsonError.empty())
+ {
+ throw std::runtime_error(
+ fmt::format("Invalid security configuration file at {}. '{}'", ServerOptions.SecurityConfigPath, JsonError));
+ }
+
+ CbObjectView HttpRootFilterConfig = SecurityConfig["http"sv].AsObjectView()["root"sv].AsObjectView()["filter"sv].AsObjectView();
+ if (HttpRootFilterConfig)
+ {
+ std::string_view FilterType = HttpRootFilterConfig["type"sv].AsString();
+ if (FilterType == PasswordHttpFilter::TypeName)
+ {
+ PasswordHttpFilter::Configuration Config =
+ PasswordHttpFilter::ReadConfiguration(HttpRootFilterConfig["config"].AsObjectView());
+ m_HttpRequestFilter = std::make_unique<PasswordHttpFilter>(Config);
+ m_Http->SetHttpRequestFilter(m_HttpRequestFilter.get());
+ }
+ else
+ {
+ throw std::runtime_error(fmt::format("Security configuration file at {} references unknown http root filter type '{}'",
+ ServerOptions.SecurityConfigPath,
+ FilterType));
+ }
+ }
+ }
+}
+
//////////////////////////////////////////////////////////////////////////
ZenServerMain::ZenServerMain(ZenServerConfig& ServerOptions) : m_ServerOptions(ServerOptions)
@@ -467,7 +588,7 @@ ZenServerMain::Run()
ZEN_OTEL_SPAN("SentryInit");
std::string SentryDatabasePath = (m_ServerOptions.DataDir / ".sentry-native").string();
- std::string SentryAttachmentPath = m_ServerOptions.AbsLogFile.string();
+ std::string SentryAttachmentPath = m_ServerOptions.LoggingConfig.AbsLogFile.string();
Sentry.Initialize({.DatabasePath = SentryDatabasePath,
.AttachmentsPath = SentryAttachmentPath,
@@ -567,6 +688,8 @@ ZenServerMain::Run()
{
ZEN_INFO(ZEN_APP_NAME " unable to grab lock at '{}' (reason: '{}'), retrying", LockFilePath, Ec.message());
Sleep(500);
+
+ m_LockFile.Create(LockFilePath, MakeLockData(false), Ec);
if (Ec)
{
ZEN_WARN(ZEN_APP_NAME " exiting, unable to grab lock at '{}' (reason: '{}')", LockFilePath, Ec.message());
@@ -622,6 +745,10 @@ ZenServerMain::Run()
RequestApplicationExit(1);
}
+#if ZEN_USE_SENTRY
+ Sentry.Close();
+#endif
+
ShutdownServerLogging();
ReportServiceStatus(ServiceStatus::Stopped);
diff --git a/src/zenserver/zenserver.h b/src/zenserver/zenserver.h
index ab7122fcc..c06093f0d 100644
--- a/src/zenserver/zenserver.h
+++ b/src/zenserver/zenserver.h
@@ -3,11 +3,13 @@
#pragma once
#include <zencore/basicfile.h>
+#include <zencore/system.h>
#include <zenhttp/httpserver.h>
#include <zenhttp/httpstats.h>
#include <zenhttp/httpstatus.h>
#include <zenutil/zenserverprocess.h>
+#include <atomic>
#include <memory>
#include <string_view>
#include "config/config.h"
@@ -43,11 +45,18 @@ public:
void SetIsReadyFunc(std::function<void()>&& IsReadyFunc) { m_IsReadyFunc = std::move(IsReadyFunc); }
+ void SetDataRoot(std::filesystem::path Root) { m_DataRoot = Root; }
+ void SetContentRoot(std::filesystem::path Root) { m_ContentRoot = Root; }
+ void SetDedicatedMode(bool State) { m_IsDedicatedMode = State; }
+ void SetTestMode(bool State) { m_TestMode = State; }
+
protected:
int Initialize(const ZenServerConfig& ServerOptions, ZenServerState::ZenServerEntry* ServerEntry);
void Finalize();
+ void ShutdownServices();
void GetBuildOptions(StringBuilderBase& OutOptions, char Separator = ',') const;
- void LogSettingsSummary(const ZenServerConfig& ServerConfig);
+ static std::vector<std::pair<std::string_view, std::string>> BuildSettingsList(const ZenServerConfig& ServerConfig);
+ void LogSettingsSummary(const ZenServerConfig& ServerConfig);
protected:
NamedMutex m_ServerMutex;
@@ -55,6 +64,10 @@ protected:
bool m_UseSentry = false;
bool m_IsPowerCycle = false;
+ bool m_IsDedicatedMode = false;
+ bool m_TestMode = false;
+ bool m_DebugOptionForcedCrash = false;
+
std::thread m_IoRunner;
asio::io_context m_IoContext;
void EnsureIoRunner();
@@ -64,17 +77,26 @@ protected:
kInitializing,
kRunning,
kShuttingDown
- } m_CurrentState = kInitializing;
+ };
+ std::atomic<ServerState> m_CurrentState = kInitializing;
- inline void SetNewState(ServerState NewState) { m_CurrentState = NewState; }
+ inline void SetNewState(ServerState NewState) { m_CurrentState.store(NewState, std::memory_order_relaxed); }
static std::string_view ToString(ServerState Value);
std::function<void()> m_IsReadyFunc;
void OnReady();
- Ref<HttpServer> m_Http;
- HttpHealthService m_HealthService;
- HttpStatusService m_StatusService;
+ std::filesystem::path m_DataRoot; // Root directory for server state
+ std::filesystem::path m_ContentRoot; // Root directory for frontend content
+
+ Ref<HttpServer> m_Http;
+
+ std::unique_ptr<IHttpRequestFilter> m_HttpRequestFilter;
+
+ HttpHealthService m_HealthService;
+ HttpStatsService m_StatsService{m_IoContext};
+ HttpStatusService m_StatusService;
+ SystemMetricsTracker m_MetricsTracker;
// Stats reporting
@@ -107,8 +129,10 @@ protected:
// IHttpStatusProvider
virtual void HandleStatusRequest(HttpServerRequest& Request) override;
-};
+private:
+ void InitializeSecuritySettings(const ZenServerConfig& ServerOptions);
+};
class ZenServerMain
{
public:
diff --git a/src/zenserver/zenserver.rc b/src/zenserver/zenserver.rc
index e0003ea8f..f353bd9cc 100644
--- a/src/zenserver/zenserver.rc
+++ b/src/zenserver/zenserver.rc
@@ -28,7 +28,7 @@ LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US
// Icon with lowest ID value placed first to ensure application icon
// remains consistent on all systems.
-IDI_ICON1 ICON "..\\UnrealEngine.ico"
+IDI_ICON1 ICON "..\\zen.ico"
#endif // English (United States) resources
/////////////////////////////////////////////////////////////////////////////
diff --git a/src/zenstore-test/zenstore-test.cpp b/src/zenstore-test/zenstore-test.cpp
index c055dbb64..875373a9d 100644
--- a/src/zenstore-test/zenstore-test.cpp
+++ b/src/zenstore-test/zenstore-test.cpp
@@ -1,45 +1,15 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zencore/filesystem.h>
-#include <zencore/logging.h>
-#include <zencore/trace.h>
+#include <zencore/testing.h>
#include <zenstore/zenstore.h>
#include <zencore/memory/newdelete.h>
-#if ZEN_WITH_TESTS
-# define ZEN_TEST_WITH_RUNNER 1
-# include <zencore/testing.h>
-# include <zencore/process.h>
-#endif
-
int
main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
{
-#if ZEN_PLATFORM_WINDOWS
- setlocale(LC_ALL, "en_us.UTF8");
-#endif // ZEN_PLATFORM_WINDOWS
-
#if ZEN_WITH_TESTS
- zen::zenstore_forcelinktests();
-
-# if ZEN_PLATFORM_LINUX
- zen::IgnoreChildSignals();
-# endif
-
-# if ZEN_WITH_TRACE
- zen::TraceInit("zenstore-test");
- zen::TraceOptions TraceCommandlineOptions;
- if (GetTraceOptionsFromCommandline(TraceCommandlineOptions))
- {
- TraceConfigure(TraceCommandlineOptions);
- }
-# endif // ZEN_WITH_TRACE
-
- zen::logging::InitializeLogging();
- zen::MaximizeOpenFileCount();
-
- return ZEN_RUN_TESTS(argc, argv);
+ return zen::testing::RunTestMain(argc, argv, "zenstore-test", zen::zenstore_forcelinktests);
#else
return 0;
#endif
diff --git a/src/zenstore/blockstore.cpp b/src/zenstore/blockstore.cpp
index 3ea91ead6..6197c7f24 100644
--- a/src/zenstore/blockstore.cpp
+++ b/src/zenstore/blockstore.cpp
@@ -1556,6 +1556,8 @@ BlockStore::GetMetaData(uint32_t BlockIndex) const
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("store.blockstore");
+
TEST_CASE("blockstore.blockstoredisklocation")
{
BlockStoreLocation Zero = BlockStoreLocation{.BlockIndex = 0, .Offset = 0, .Size = 0};
@@ -2427,6 +2429,8 @@ TEST_CASE("blockstore.BlockStoreFileAppender")
}
}
+TEST_SUITE_END();
+
#endif
void
diff --git a/src/zenstore/buildstore/buildstore.cpp b/src/zenstore/buildstore/buildstore.cpp
index 04a0781d3..dff1c3c61 100644
--- a/src/zenstore/buildstore/buildstore.cpp
+++ b/src/zenstore/buildstore/buildstore.cpp
@@ -266,13 +266,12 @@ BuildStore::PutBlob(const IoHash& BlobHash, const IoBuffer& Payload)
m_BlobLookup.insert({BlobHash, NewBlobIndex});
}
- m_LastAccessTimeUpdateCount++;
if (m_TrackedBlobKeys)
{
m_TrackedBlobKeys->push_back(BlobHash);
if (MetadataHash != IoHash::Zero)
{
- m_TrackedBlobKeys->push_back(BlobHash);
+ m_TrackedBlobKeys->push_back(MetadataHash);
}
}
}
@@ -374,8 +373,8 @@ BuildStore::PutMetadatas(std::span<const IoHash> BlobHashes, std::span<const IoB
CompressedMetadataBuffers.resize(Metadatas.size());
if (OptionalWorkerPool)
{
- std::atomic<bool> AbortFlag;
- std::atomic<bool> PauseFlag;
+ std::atomic<bool> AbortFlag{false};
+ std::atomic<bool> PauseFlag{false};
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog);
for (size_t Index = 0; Index < Metadatas.size(); Index++)
{
@@ -506,8 +505,8 @@ BuildStore::GetMetadatas(std::span<const IoHash> BlobHashes, WorkerThreadPool* O
else
{
ZEN_WARN("Metadata {} for blob {} is malformed (not a compressed binary format)",
- MetadataHashes[ResultIndex],
- BlobHashes[ResultIndex]);
+ MetadataHashes[Index],
+ BlobHashes[MetaLocationResultIndexes[Index]]);
}
}
}
@@ -562,7 +561,7 @@ BuildStore::GetStorageStats() const
RwLock::SharedLockScope _(m_Lock);
Result.EntryCount = m_BlobLookup.size();
- for (auto LookupIt : m_BlobLookup)
+ for (const auto& LookupIt : m_BlobLookup)
{
const BlobIndex ReadBlobIndex = LookupIt.second;
const BlobEntry& ReadBlobEntry = m_BlobEntries[ReadBlobIndex];
@@ -635,7 +634,7 @@ BuildStore::CompactState()
const size_t MetadataCount = m_MetadataEntries.size();
MetadataEntries.reserve(MetadataCount);
- for (auto LookupIt : m_BlobLookup)
+ for (const auto& LookupIt : m_BlobLookup)
{
const IoHash& BlobHash = LookupIt.first;
const BlobIndex ReadBlobIndex = LookupIt.second;
@@ -956,7 +955,7 @@ BuildStore::WriteAccessTimes(const RwLock::ExclusiveLockScope&, const std::files
std::vector<AccessTimeRecord> AccessRecords;
AccessRecords.reserve(Header.AccessTimeCount);
- for (auto It : m_BlobLookup)
+ for (const auto& It : m_BlobLookup)
{
const IoHash& Key = It.first;
const BlobIndex Index = It.second;
@@ -966,7 +965,7 @@ BuildStore::WriteAccessTimes(const RwLock::ExclusiveLockScope&, const std::files
}
uint64_t RecordsSize = sizeof(AccessTimeRecord) * Header.AccessTimeCount;
TempFile.Write(AccessRecords.data(), RecordsSize, Offset);
- Offset += sizeof(AccessTimesHeader) * Header.AccessTimeCount;
+ Offset += sizeof(AccessTimeRecord) * Header.AccessTimeCount;
}
if (TempFile.MoveTemporaryIntoPlace(AccessTimesPath, Ec); Ec)
{
@@ -1373,6 +1372,8 @@ BuildStore::LockState(GcCtx& Ctx)
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("store.buildstore");
+
TEST_CASE("BuildStore.Blobs")
{
ScopedTemporaryDirectory _;
@@ -1822,6 +1823,8 @@ TEST_CASE("BuildStore.SizeLimit")
}
}
+TEST_SUITE_END();
+
void
buildstore_forcelink()
{
diff --git a/src/zenstore/cache/cachedisklayer.cpp b/src/zenstore/cache/cachedisklayer.cpp
index ead7e4f3a..4640309d9 100644
--- a/src/zenstore/cache/cachedisklayer.cpp
+++ b/src/zenstore/cache/cachedisklayer.cpp
@@ -602,7 +602,7 @@ BucketManifestSerializer::ReadSidecarFile(RwLock::ExclusiveLockScope& B
if (FileSize < sizeof(BucketMetaHeader))
{
- ZEN_WARN("Failed to read sidecar file '{}'. Minimum size {} expected, actual size: ",
+ ZEN_WARN("Failed to read sidecar file '{}'. Minimum size {} expected, actual size: {}",
SidecarPath,
sizeof(BucketMetaHeader),
FileSize);
@@ -626,7 +626,7 @@ BucketManifestSerializer::ReadSidecarFile(RwLock::ExclusiveLockScope& B
return false;
}
- const uint64_t ExpectedEntryCount = (FileSize - sizeof(sizeof(BucketMetaHeader))) / sizeof(ManifestData);
+ const uint64_t ExpectedEntryCount = (FileSize - sizeof(BucketMetaHeader)) / sizeof(ManifestData);
if (Header.EntryCount > ExpectedEntryCount)
{
ZEN_WARN(
@@ -654,6 +654,7 @@ BucketManifestSerializer::ReadSidecarFile(RwLock::ExclusiveLockScope& B
SidecarPath,
sizeof(ManifestData),
CurrentReadOffset);
+ break;
}
CurrentReadOffset += sizeof(ManifestData);
@@ -1011,7 +1012,7 @@ ZenCacheDiskLayer::CacheBucket::WriteIndexSnapshotLocked(uint64_t LogPosi
{
// This is non-critical, it only means that we will replay the events of the log over the snapshot - inefficent but in
// the end it will be the same result
- ZEN_WARN("snapshot failed to clean log file '{}', reason: '{}'", LogPath, IndexPath, Ec.message());
+ ZEN_WARN("snapshot failed to clean log file '{}', reason: '{}'", LogPath, Ec.message());
}
m_SlogFile.Open(LogPath, CasLogFile::Mode::kWrite);
}
@@ -1057,7 +1058,7 @@ ZenCacheDiskLayer::CacheBucket::ReadIndexFile(RwLock::ExclusiveLockScope&, const
return 0;
}
- const uint64_t ExpectedEntryCount = (FileSize - sizeof(sizeof(cache::impl::CacheBucketIndexHeader))) / sizeof(DiskIndexEntry);
+ const uint64_t ExpectedEntryCount = (FileSize - sizeof(cache::impl::CacheBucketIndexHeader)) / sizeof(DiskIndexEntry);
if (Header.EntryCount > ExpectedEntryCount)
{
return 0;
@@ -1267,10 +1268,10 @@ ZenCacheDiskLayer::CacheBucket::InitializeIndexFromDisk(RwLock::ExclusiveLockSco
{
RemoveMemCachedData(IndexLock, Payload);
RemoveMetaData(IndexLock, Payload);
+ Location.Flags |= DiskLocation::kTombStone;
+ MissingEntries.push_back(DiskIndexEntry{.Key = It.first, .Location = Location});
}
}
- Location.Flags |= DiskLocation::kTombStone;
- MissingEntries.push_back(DiskIndexEntry{.Key = It.first, .Location = Location});
}
ZEN_ASSERT(!MissingEntries.empty());
@@ -2812,7 +2813,7 @@ ZenCacheDiskLayer::CacheBucket::PutStandaloneCacheValue(const IoHash& HashKey, c
m_BucketDir,
Ec.message(),
RetriesLeft);
- Sleep(100 - (3 - RetriesLeft) * 100); // Total 600 ms
+ Sleep((3 - RetriesLeft) * 100); // Total 600 ms
Ec.clear();
DataFile.MoveTemporaryIntoPlace(FsPath, Ec);
RetriesLeft--;
@@ -2866,11 +2867,12 @@ ZenCacheDiskLayer::CacheBucket::PutStandaloneCacheValue(const IoHash& HashKey, c
{
EntryIndex = It.value();
ZEN_ASSERT_SLOW(EntryIndex < PayloadIndex(m_AccessTimes.size()));
- BucketPayload& Payload = m_Payloads[EntryIndex];
- uint64_t OldSize = Payload.Location.Size();
+ BucketPayload& Payload = m_Payloads[EntryIndex];
+ uint64_t OldSize = Payload.Location.Size();
+ RemoveMemCachedData(IndexLock, Payload);
+ RemoveMetaData(IndexLock, Payload);
Payload = BucketPayload{.Location = Loc};
m_AccessTimes[EntryIndex] = GcClock::TickCount();
- RemoveMemCachedData(IndexLock, Payload);
m_StandaloneSize.fetch_sub(OldSize, std::memory_order::relaxed);
}
if ((Value.RawSize != 0 || Value.RawHash != IoHash::Zero) && Value.RawSize <= std::numeric_limits<std::uint32_t>::max())
@@ -3521,7 +3523,7 @@ ZenCacheDiskLayer::CacheBucket::GetReferences(const LoggerRef& Logger,
}
else
{
- ZEN_WARN("Cache record {} payload is malformed. Reason: ", RawHash, ToString(Error));
+ ZEN_WARN("Cache record {} payload is malformed. Reason: {}", RawHash, ToString(Error));
}
return false;
};
@@ -4282,8 +4284,8 @@ ZenCacheDiskLayer::DiscoverBuckets()
RwLock SyncLock;
WorkerThreadPool& Pool = GetLargeWorkerPool(EWorkloadType::Burst);
- std::atomic<bool> AbortFlag;
- std::atomic<bool> PauseFlag;
+ std::atomic<bool> AbortFlag{false};
+ std::atomic<bool> PauseFlag{false};
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog);
try
{
@@ -4454,8 +4456,8 @@ ZenCacheDiskLayer::Flush()
}
{
WorkerThreadPool& Pool = GetMediumWorkerPool(EWorkloadType::Burst);
- std::atomic<bool> AbortFlag;
- std::atomic<bool> PauseFlag;
+ std::atomic<bool> AbortFlag{false};
+ std::atomic<bool> PauseFlag{false};
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog);
try
{
@@ -4496,8 +4498,8 @@ ZenCacheDiskLayer::Scrub(ScrubContext& Ctx)
RwLock::SharedLockScope _(m_Lock);
- std::atomic<bool> Abort;
- std::atomic<bool> Pause;
+ std::atomic<bool> Abort{false};
+ std::atomic<bool> Pause{false};
ParallelWork Work(Abort, Pause, WorkerThreadPool::EMode::DisableBacklog);
try
@@ -4559,9 +4561,11 @@ ZenCacheDiskLayer::Stats() const
ZenCacheDiskLayer::Info
ZenCacheDiskLayer::GetInfo() const
{
- ZenCacheDiskLayer::Info Info = {.RootDir = m_RootDir, .Config = m_Configuration};
+ ZenCacheDiskLayer::Info Info;
+ Info.RootDir = m_RootDir;
{
RwLock::SharedLockScope _(m_Lock);
+ Info.Config = m_Configuration;
Info.BucketNames.reserve(m_Buckets.size());
for (auto& Kv : m_Buckets)
{
diff --git a/src/zenstore/cache/cachepolicy.cpp b/src/zenstore/cache/cachepolicy.cpp
index ca8a95ca1..c1e7dc5b3 100644
--- a/src/zenstore/cache/cachepolicy.cpp
+++ b/src/zenstore/cache/cachepolicy.cpp
@@ -284,6 +284,9 @@ CacheRecordPolicyBuilder::Build()
}
#if ZEN_WITH_TESTS
+
+TEST_SUITE_BEGIN("store.cachepolicy");
+
TEST_CASE("cachepolicy")
{
SUBCASE("atomics serialization")
@@ -400,13 +403,13 @@ TEST_CASE("cacherecordpolicy")
RecordPolicy.Save(Writer);
CbObject Saved = Writer.Save()->AsObject();
CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved).Get();
- CHECK(!RecordPolicy.IsUniform());
- CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy);
- CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy);
- CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap);
- CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap);
- CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy);
- CHECK(RecordPolicy.GetValuePolicies().size() == 2);
+ CHECK(!Loaded.IsUniform());
+ CHECK(Loaded.GetRecordPolicy() == UnionPolicy);
+ CHECK(Loaded.GetBasePolicy() == DefaultPolicy);
+ CHECK(Loaded.GetValuePolicy(PartialOid) == PartialOverlap);
+ CHECK(Loaded.GetValuePolicy(NoOverlapOid) == NoOverlap);
+ CHECK(Loaded.GetValuePolicy(OtherOid) == DefaultValuePolicy);
+ CHECK(Loaded.GetValuePolicies().size() == 2);
}
}
@@ -416,6 +419,8 @@ TEST_CASE("cacherecordpolicy")
CHECK(Loaded.IsNull());
}
}
+
+TEST_SUITE_END();
#endif
void
diff --git a/src/zenstore/cache/cacherpc.cpp b/src/zenstore/cache/cacherpc.cpp
index 94abcf547..90c5a5e60 100644
--- a/src/zenstore/cache/cacherpc.cpp
+++ b/src/zenstore/cache/cacherpc.cpp
@@ -866,8 +866,8 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, Cb
Request.Complete = false;
}
}
- Request.ElapsedTimeUs += Timer.GetElapsedTimeUs();
}
+ Request.ElapsedTimeUs += Timer.GetElapsedTimeUs();
};
m_UpstreamCache.GetCacheRecords(*Namespace, UpstreamRequests, std::move(OnCacheRecordGetComplete));
@@ -934,7 +934,7 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, Cb
*Namespace,
Key.Bucket,
Key.Hash,
- Request.RecordObject ? ""sv : " (PARTIAL)"sv,
+ Request.RecordObject ? " (PARTIAL)"sv : ""sv,
Request.Source ? Request.Source->Url : "LOCAL"sv,
NiceLatencyNs(Request.ElapsedTimeUs * 1000));
m_CacheStats.MissCount++;
@@ -966,7 +966,7 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, Cb
}
else
{
- ResponseObject.AddBool(true);
+ ResponseObject.AddBool(false);
}
}
ResponseObject.EndArray();
diff --git a/src/zenstore/cache/structuredcachestore.cpp b/src/zenstore/cache/structuredcachestore.cpp
index 52b494e45..cff0e9a35 100644
--- a/src/zenstore/cache/structuredcachestore.cpp
+++ b/src/zenstore/cache/structuredcachestore.cpp
@@ -608,7 +608,10 @@ ZenCacheStore::GetBatch::Commit()
m_CacheStore.m_HitCount++;
OpScope.SetBytes(Result.Value.GetSize());
}
- m_CacheStore.m_MissCount++;
+ else
+ {
+ m_CacheStore.m_MissCount++;
+ }
}
}
}
@@ -683,8 +686,8 @@ ZenCacheStore::Get(const CacheRequestContext& Context,
return false;
}
ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Get [{}], bucket '{}', key '{}'",
- Context,
Namespace,
+ Context,
Bucket,
HashKey.ToHexString());
@@ -719,8 +722,8 @@ ZenCacheStore::Get(const CacheRequestContext& Context,
}
ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Get [{}], bucket '{}', key '{}'",
- Context,
Namespace,
+ Context,
Bucket,
HashKey.ToHexString());
@@ -787,8 +790,8 @@ ZenCacheStore::Put(const CacheRequestContext& Context,
}
ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Put [{}] bucket '{}', key '{}'",
- Context,
Namespace,
+ Context,
Bucket,
HashKey.ToHexString());
@@ -813,7 +816,7 @@ ZenCacheStore::DropNamespace(std::string_view InNamespace)
{
std::function<void()> PostDropOp;
{
- RwLock::SharedLockScope _(m_NamespacesLock);
+ RwLock::ExclusiveLockScope _(m_NamespacesLock);
if (auto It = m_Namespaces.find(std::string(InNamespace)); It != m_Namespaces.end())
{
ZenCacheNamespace& Namespace = *It->second;
@@ -1392,6 +1395,8 @@ namespace testutils {
} // namespace testutils
+TEST_SUITE_BEGIN("store.structuredcachestore");
+
TEST_CASE("cachestore.store")
{
ScopedTemporaryDirectory TempDir;
@@ -1548,7 +1553,7 @@ TEST_CASE("cachestore.size")
}
}
-TEST_CASE("cachestore.threadedinsert") // * doctest::skip(true))
+TEST_CASE("cachestore.threadedinsert" * doctest::skip())
{
// for (uint32_t i = 0; i < 100; ++i)
{
@@ -2741,6 +2746,8 @@ TEST_CASE("cachestore.newgc.basics")
}
}
+TEST_SUITE_END();
+
#endif
void
diff --git a/src/zenstore/cas.cpp b/src/zenstore/cas.cpp
index ed017988f..8855c87d8 100644
--- a/src/zenstore/cas.cpp
+++ b/src/zenstore/cas.cpp
@@ -153,7 +153,10 @@ CasImpl::Initialize(const CidStoreConfiguration& InConfig)
}
for (std::future<void>& Result : Work)
{
- Result.get();
+ if (Result.valid())
+ {
+ Result.get();
+ }
}
}
}
@@ -300,12 +303,12 @@ GetCompactCasResults(CasContainerStrategy& Strategy,
};
static void
-GetFileCasResults(FileCasStrategy& Strategy,
- CasStore::InsertMode Mode,
- std::span<IoBuffer> Data,
- std::span<IoHash> ChunkHashes,
- std::span<size_t> Indexes,
- std::vector<CasStore::InsertResult> Results)
+GetFileCasResults(FileCasStrategy& Strategy,
+ CasStore::InsertMode Mode,
+ std::span<IoBuffer> Data,
+ std::span<IoHash> ChunkHashes,
+ std::span<size_t> Indexes,
+ std::vector<CasStore::InsertResult>& Results)
{
for (size_t Index : Indexes)
{
@@ -426,7 +429,7 @@ CasImpl::IterateChunks(std::span<IoHash> DecompressedIds,
[&](size_t Index, const IoBuffer& Payload) {
IoBuffer Chunk(Payload);
Chunk.SetContentType(ZenContentType::kCompressedBinary);
- return AsyncCallback(Index, Payload);
+ return AsyncCallback(Index, Chunk);
},
OptionalWorkerPool,
LargeSizeLimit == 0 ? m_Config.HugeValueThreshold : Min(LargeSizeLimit, m_Config.HugeValueThreshold)))
@@ -439,7 +442,7 @@ CasImpl::IterateChunks(std::span<IoHash> DecompressedIds,
[&](size_t Index, const IoBuffer& Payload) {
IoBuffer Chunk(Payload);
Chunk.SetContentType(ZenContentType::kCompressedBinary);
- return AsyncCallback(Index, Payload);
+ return AsyncCallback(Index, Chunk);
},
OptionalWorkerPool,
LargeSizeLimit == 0 ? m_Config.TinyValueThreshold : Min(LargeSizeLimit, m_Config.TinyValueThreshold)))
@@ -452,7 +455,7 @@ CasImpl::IterateChunks(std::span<IoHash> DecompressedIds,
[&](size_t Index, const IoBuffer& Payload) {
IoBuffer Chunk(Payload);
Chunk.SetContentType(ZenContentType::kCompressedBinary);
- return AsyncCallback(Index, Payload);
+ return AsyncCallback(Index, Chunk);
},
OptionalWorkerPool))
{
@@ -512,6 +515,8 @@ CreateCasStore(GcManager& Gc)
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("store.cas");
+
TEST_CASE("CasStore")
{
ScopedTemporaryDirectory TempDir;
@@ -553,6 +558,8 @@ TEST_CASE("CasStore")
CHECK(Lookup2);
}
+TEST_SUITE_END();
+
void
CAS_forcelink()
{
diff --git a/src/zenstore/caslog.cpp b/src/zenstore/caslog.cpp
index 492ce9317..44664dac2 100644
--- a/src/zenstore/caslog.cpp
+++ b/src/zenstore/caslog.cpp
@@ -35,7 +35,7 @@ CasLogFile::~CasLogFile()
}
bool
-CasLogFile::IsValid(std::filesystem::path FileName, size_t RecordSize)
+CasLogFile::IsValid(const std::filesystem::path& FileName, size_t RecordSize)
{
if (!IsFile(FileName))
{
@@ -71,7 +71,7 @@ CasLogFile::IsValid(std::filesystem::path FileName, size_t RecordSize)
}
void
-CasLogFile::Open(std::filesystem::path FileName, size_t RecordSize, Mode Mode)
+CasLogFile::Open(const std::filesystem::path& FileName, size_t RecordSize, Mode Mode)
{
m_RecordSize = RecordSize;
@@ -205,7 +205,7 @@ CasLogFile::Replay(std::function<void(const void*)>&& Handler, uint64_t SkipEntr
m_File.Read(ReadBuffer.data(), BytesToRead, LogBaseOffset + ReadOffset);
- for (int i = 0; i < int(EntriesToRead); ++i)
+ for (size_t i = 0; i < EntriesToRead; ++i)
{
Handler(ReadBuffer.data() + (i * m_RecordSize));
}
diff --git a/src/zenstore/cidstore.cpp b/src/zenstore/cidstore.cpp
index bedf91287..b20d8f565 100644
--- a/src/zenstore/cidstore.cpp
+++ b/src/zenstore/cidstore.cpp
@@ -48,13 +48,13 @@ struct CidStore::Impl
std::vector<CidStore::InsertResult> AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes, CidStore::InsertMode Mode)
{
+ ZEN_ASSERT(ChunkDatas.size() == RawHashes.size());
if (ChunkDatas.size() == 1)
{
std::vector<CidStore::InsertResult> Result(1);
Result[0] = AddChunk(ChunkDatas[0], RawHashes[0], Mode);
return Result;
}
- ZEN_ASSERT(ChunkDatas.size() == RawHashes.size());
std::vector<IoBuffer> Chunks;
Chunks.reserve(ChunkDatas.size());
#if ZEN_BUILD_DEBUG
@@ -81,6 +81,7 @@ struct CidStore::Impl
m_CasStore.InsertChunks(Chunks, RawHashes, static_cast<CasStore::InsertMode>(Mode));
ZEN_ASSERT(CasResults.size() == ChunkDatas.size());
std::vector<CidStore::InsertResult> Result;
+ Result.reserve(CasResults.size());
for (const CasStore::InsertResult& CasResult : CasResults)
{
if (CasResult.New)
diff --git a/src/zenstore/compactcas.cpp b/src/zenstore/compactcas.cpp
index 5d8f95c9e..b09892687 100644
--- a/src/zenstore/compactcas.cpp
+++ b/src/zenstore/compactcas.cpp
@@ -153,7 +153,7 @@ CasContainerStrategy::~CasContainerStrategy()
}
catch (const std::exception& Ex)
{
- ZEN_ERROR("~CasContainerStrategy failed with: ", Ex.what());
+ ZEN_ERROR("~CasContainerStrategy failed with: {}", Ex.what());
}
m_Gc.RemoveGcReferenceStore(*this);
m_Gc.RemoveGcStorage(this);
@@ -440,9 +440,9 @@ CasContainerStrategy::IterateChunks(std::span<const IoHash> ChunkHas
return true;
}
- std::atomic<bool> AbortFlag;
+ std::atomic<bool> AbortFlag{false};
{
- std::atomic<bool> PauseFlag;
+ std::atomic<bool> PauseFlag{false};
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog);
try
{
@@ -559,8 +559,8 @@ CasContainerStrategy::ScrubStorage(ScrubContext& Ctx)
std::vector<BlockStoreLocation> ChunkLocations;
std::vector<IoHash> ChunkIndexToChunkHash;
- std::atomic<bool> Abort;
- std::atomic<bool> Pause;
+ std::atomic<bool> Abort{false};
+ std::atomic<bool> Pause{false};
ParallelWork Work(Abort, Pause, WorkerThreadPool::EMode::DisableBacklog);
try
@@ -1007,7 +1007,7 @@ CasContainerStrategy::CompactIndex(RwLock::ExclusiveLockScope&)
std::vector<BlockStoreDiskLocation> Locations;
Locations.reserve(EntryCount);
LocationMap.reserve(EntryCount);
- for (auto It : m_LocationMap)
+ for (const auto& It : m_LocationMap)
{
size_t EntryIndex = Locations.size();
Locations.push_back(m_Locations[It.second]);
@@ -1106,7 +1106,7 @@ CasContainerStrategy::MakeIndexSnapshot(bool ResetLog)
{
// This is non-critical, it only means that we will replay the events of the log over the snapshot - inefficent but in
// the end it will be the same result
- ZEN_WARN("Snapshot failed to clean log file '{}', reason: '{}'", LogPath, IndexPath, Ec.message());
+ ZEN_WARN("Snapshot failed to clean log file '{}', reason: '{}'", LogPath, Ec.message());
}
m_CasLog.Open(LogPath, CasLogFile::Mode::kWrite);
}
@@ -1136,7 +1136,7 @@ CasContainerStrategy::ReadIndexFile(const std::filesystem::path& IndexPath, uint
uint64_t Size = ObjectIndexFile.FileSize();
if (Size >= sizeof(CasDiskIndexHeader))
{
- uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(CasDiskIndexHeader))) / sizeof(CasDiskIndexEntry);
+ uint64_t ExpectedEntryCount = (Size - sizeof(CasDiskIndexHeader)) / sizeof(CasDiskIndexEntry);
CasDiskIndexHeader Header;
ObjectIndexFile.Read(&Header, sizeof(Header), 0);
if ((Header.Magic == CasDiskIndexHeader::ExpectedMagic) && (Header.Version == CasDiskIndexHeader::CurrentVersion) &&
@@ -1348,6 +1348,8 @@ CasContainerStrategy::OpenContainer(bool IsNewStore)
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("store.compactcas");
+
TEST_CASE("compactcas.hex")
{
uint32_t Value;
@@ -2159,6 +2161,8 @@ TEST_CASE("compactcas.iteratechunks")
}
}
+TEST_SUITE_END();
+
#endif
void
diff --git a/src/zenstore/filecas.cpp b/src/zenstore/filecas.cpp
index 31b3a68c4..0088afe6e 100644
--- a/src/zenstore/filecas.cpp
+++ b/src/zenstore/filecas.cpp
@@ -383,7 +383,7 @@ FileCasStrategy::InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash, CasStore::
HRESULT WriteRes = PayloadFile.Write(Cursor, Size);
if (FAILED(WriteRes))
{
- ThrowSystemException(hRes, fmt::format("failed to write {} bytes to shard file '{}'", ChunkSize, ChunkPath));
+ ThrowSystemException(WriteRes, fmt::format("failed to write {} bytes to shard file '{}'", ChunkSize, ChunkPath));
}
};
#else
@@ -669,8 +669,8 @@ FileCasStrategy::IterateChunks(std::span<IoHash> ChunkHashes,
return true;
};
- std::atomic<bool> AbortFlag;
- std::atomic<bool> PauseFlag;
+ std::atomic<bool> AbortFlag{false};
+ std::atomic<bool> PauseFlag{false};
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog);
try
{
@@ -823,8 +823,8 @@ FileCasStrategy::ScrubStorage(ScrubContext& Ctx)
ZEN_INFO("discovered {} files @ '{}' ({} not in index), scrubbing", m_Index.size(), m_RootDirectory, DiscoveredFilesNotInIndex);
- std::atomic<bool> Abort;
- std::atomic<bool> Pause;
+ std::atomic<bool> Abort{false};
+ std::atomic<bool> Pause{false};
ParallelWork Work(Abort, Pause, WorkerThreadPool::EMode::DisableBacklog);
try
@@ -1016,7 +1016,7 @@ FileCasStrategy::MakeIndexSnapshot(bool ResetLog)
{
// This is non-critical, it only means that we will replay the events of the log over the snapshot - inefficent but in
// the end it will be the same result
- ZEN_WARN("Snapshot failed to clean log file '{}', reason: '{}'", LogPath, IndexPath, Ec.message());
+ ZEN_WARN("Snapshot failed to clean log file '{}', reason: '{}'", LogPath, Ec.message());
}
m_CasLog.Open(LogPath, CasLogFile::Mode::kWrite);
}
@@ -1052,7 +1052,7 @@ FileCasStrategy::ReadIndexFile(const std::filesystem::path& IndexPath, uint32_t&
uint64_t Size = ObjectIndexFile.FileSize();
if (Size >= sizeof(FileCasIndexHeader))
{
- uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(FileCasIndexHeader))) / sizeof(FileCasIndexEntry);
+ uint64_t ExpectedEntryCount = (Size - sizeof(FileCasIndexHeader)) / sizeof(FileCasIndexEntry);
FileCasIndexHeader Header;
ObjectIndexFile.Read(&Header, sizeof(Header), 0);
if ((Header.Magic == FileCasIndexHeader::ExpectedMagic) && (Header.Version == FileCasIndexHeader::CurrentVersion) &&
@@ -1496,6 +1496,8 @@ FileCasStrategy::CreateReferencePruner(GcCtx& Ctx, GcReferenceStoreStats&)
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("store.filecas");
+
TEST_CASE("cas.chunk.mismatch")
{
}
@@ -1793,6 +1795,8 @@ TEST_CASE("cas.file.move")
# endif
}
+TEST_SUITE_END();
+
#endif
void
diff --git a/src/zenstore/filecas.h b/src/zenstore/filecas.h
index e93356927..41756b65f 100644
--- a/src/zenstore/filecas.h
+++ b/src/zenstore/filecas.h
@@ -74,7 +74,7 @@ private:
{
static const uint32_t kTombStone = 0x0000'0001;
- bool IsFlagSet(const uint32_t Flag) const { return (Flags & kTombStone) == Flag; }
+ bool IsFlagSet(const uint32_t Flag) const { return (Flags & Flag) == Flag; }
IoHash Key;
uint32_t Flags = 0;
diff --git a/src/zenstore/gc.cpp b/src/zenstore/gc.cpp
index 14caa5abf..b3450b805 100644
--- a/src/zenstore/gc.cpp
+++ b/src/zenstore/gc.cpp
@@ -1494,7 +1494,8 @@ GcManager::CollectGarbage(const GcSettings& Settings)
GcReferenceValidatorStats& Stats = Result.ReferenceValidatorStats[It.second].second;
try
{
- // Go through all the ReferenceCheckers to see if the list of Cids the collector selected are referenced or
+ // Go through all the ReferenceCheckers to see if the list of Cids the collector selected
+ // are referenced or not
SCOPED_TIMER(Stats.ElapsedMS = std::chrono::milliseconds(Timer.GetElapsedTimeMs()););
ReferenceValidator->Validate(Ctx, Stats);
}
@@ -1952,7 +1953,7 @@ GcScheduler::AppendGCLog(std::string_view Id, GcClock::TimePoint StartTime, cons
Writer << "SingleThread"sv << Settings.SingleThread;
Writer << "CompactBlockUsageThresholdPercent"sv << Settings.CompactBlockUsageThresholdPercent;
Writer << "AttachmentRangeMin"sv << Settings.AttachmentRangeMin;
- Writer << "AttachmentRangeMax"sv << Settings.AttachmentRangeMin;
+ Writer << "AttachmentRangeMax"sv << Settings.AttachmentRangeMax;
Writer << "ForceStoreCacheAttachmentMetaData"sv << Settings.StoreCacheAttachmentMetaData;
Writer << "ForceStoreProjectAttachmentMetaData"sv << Settings.StoreProjectAttachmentMetaData;
Writer << "EnableValidation"sv << Settings.EnableValidation;
@@ -2893,7 +2894,7 @@ GcScheduler::CollectGarbage(const GcClock::TimePoint& CacheExpireTime,
{
m_LastFullGCV2Result = Result;
m_LastFullAttachmentRangeMin = AttachmentRangeMin;
- m_LastFullAttachmentRangeMin = AttachmentRangeMax;
+ m_LastFullAttachmentRangeMax = AttachmentRangeMax;
}
Diff.DiskSize = Result.CompactStoresStatSum.RemovedDisk;
Diff.MemorySize = Result.ReferencerStatSum.RemoveExpiredDataStats.FreedMemory;
@@ -3048,6 +3049,8 @@ GcScheduler::CollectGarbage(const GcClock::TimePoint& CacheExpireTime,
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("store.gc");
+
TEST_CASE("gc.diskusagewindow")
{
DiskUsageWindow Stats;
@@ -3379,6 +3382,8 @@ TEST_CASE("gc.attachmentrange")
CHECK(AttachmentRangeMax == IoHash::Max);
}
+TEST_SUITE_END();
+
#endif
void
diff --git a/src/zenstore/include/zenstore/buildstore/buildstore.h b/src/zenstore/include/zenstore/buildstore/buildstore.h
index 76cba05b9..ea2ef7f89 100644
--- a/src/zenstore/include/zenstore/buildstore/buildstore.h
+++ b/src/zenstore/include/zenstore/buildstore/buildstore.h
@@ -1,5 +1,5 @@
-
// Copyright Epic Games, Inc. All Rights Reserved.
+#pragma once
#include <zenstore/blockstore.h>
@@ -223,7 +223,7 @@ private:
uint64_t m_MetaLogFlushPosition = 0;
std::unique_ptr<std::vector<IoHash>> m_TrackedBlobKeys;
- std::atomic<uint64_t> m_LastAccessTimeUpdateCount;
+ std::atomic<uint64_t> m_LastAccessTimeUpdateCount{0};
friend class BuildStoreGcReferenceChecker;
friend class BuildStoreGcReferencePruner;
diff --git a/src/zenstore/include/zenstore/cache/cachedisklayer.h b/src/zenstore/include/zenstore/cache/cachedisklayer.h
index 3d684587d..393e289ac 100644
--- a/src/zenstore/include/zenstore/cache/cachedisklayer.h
+++ b/src/zenstore/include/zenstore/cache/cachedisklayer.h
@@ -153,14 +153,14 @@ public:
struct BucketStats
{
- uint64_t DiskSize;
- uint64_t MemorySize;
- uint64_t DiskHitCount;
- uint64_t DiskMissCount;
- uint64_t DiskWriteCount;
- uint64_t MemoryHitCount;
- uint64_t MemoryMissCount;
- uint64_t MemoryWriteCount;
+ uint64_t DiskSize = 0;
+ uint64_t MemorySize = 0;
+ uint64_t DiskHitCount = 0;
+ uint64_t DiskMissCount = 0;
+ uint64_t DiskWriteCount = 0;
+ uint64_t MemoryHitCount = 0;
+ uint64_t MemoryMissCount = 0;
+ uint64_t MemoryWriteCount = 0;
metrics::RequestStatsSnapshot PutOps;
metrics::RequestStatsSnapshot GetOps;
};
@@ -174,8 +174,8 @@ public:
struct DiskStats
{
std::vector<NamedBucketStats> BucketStats;
- uint64_t DiskSize;
- uint64_t MemorySize;
+ uint64_t DiskSize = 0;
+ uint64_t MemorySize = 0;
};
struct PutResult
@@ -395,12 +395,12 @@ public:
TCasLogFile<DiskIndexEntry> m_SlogFile;
uint64_t m_LogFlushPosition = 0;
- std::atomic<uint64_t> m_DiskHitCount;
- std::atomic<uint64_t> m_DiskMissCount;
- std::atomic<uint64_t> m_DiskWriteCount;
- std::atomic<uint64_t> m_MemoryHitCount;
- std::atomic<uint64_t> m_MemoryMissCount;
- std::atomic<uint64_t> m_MemoryWriteCount;
+ std::atomic<uint64_t> m_DiskHitCount{0};
+ std::atomic<uint64_t> m_DiskMissCount{0};
+ std::atomic<uint64_t> m_DiskWriteCount{0};
+ std::atomic<uint64_t> m_MemoryHitCount{0};
+ std::atomic<uint64_t> m_MemoryMissCount{0};
+ std::atomic<uint64_t> m_MemoryWriteCount{0};
metrics::RequestStats m_PutOps;
metrics::RequestStats m_GetOps;
@@ -540,7 +540,7 @@ private:
Configuration m_Configuration;
std::atomic_uint64_t m_TotalMemCachedSize{};
std::atomic_bool m_IsMemCacheTrimming = false;
- std::atomic<GcClock::Tick> m_NextAllowedTrimTick;
+ std::atomic<GcClock::Tick> m_NextAllowedTrimTick{};
mutable RwLock m_Lock;
BucketMap_t m_Buckets;
std::vector<std::unique_ptr<CacheBucket>> m_DroppedBuckets;
diff --git a/src/zenstore/include/zenstore/cache/cacheshared.h b/src/zenstore/include/zenstore/cache/cacheshared.h
index 791720589..8e9cd7fd7 100644
--- a/src/zenstore/include/zenstore/cache/cacheshared.h
+++ b/src/zenstore/include/zenstore/cache/cacheshared.h
@@ -40,12 +40,12 @@ struct CacheValueDetails
{
struct ValueDetails
{
- uint64_t Size;
- uint64_t RawSize;
+ uint64_t Size = 0;
+ uint64_t RawSize = 0;
IoHash RawHash;
GcClock::Tick LastAccess{};
std::vector<IoHash> Attachments;
- ZenContentType ContentType;
+ ZenContentType ContentType = ZenContentType::kBinary;
};
struct BucketDetails
diff --git a/src/zenstore/include/zenstore/cache/structuredcachestore.h b/src/zenstore/include/zenstore/cache/structuredcachestore.h
index 5a0a8b069..3722a0d31 100644
--- a/src/zenstore/include/zenstore/cache/structuredcachestore.h
+++ b/src/zenstore/include/zenstore/cache/structuredcachestore.h
@@ -70,9 +70,9 @@ public:
struct NamespaceStats
{
- uint64_t HitCount;
- uint64_t MissCount;
- uint64_t WriteCount;
+ uint64_t HitCount = 0;
+ uint64_t MissCount = 0;
+ uint64_t WriteCount = 0;
metrics::RequestStatsSnapshot PutOps;
metrics::RequestStatsSnapshot GetOps;
ZenCacheDiskLayer::DiskStats DiskStats;
@@ -342,11 +342,11 @@ private:
void LogWorker();
RwLock m_LogQueueLock;
std::vector<AccessLogItem> m_LogQueue;
- std::atomic_bool m_ExitLogging;
+ std::atomic_bool m_ExitLogging{false};
Event m_LogEvent;
std::thread m_AsyncLoggingThread;
- std::atomic_bool m_WriteLogEnabled;
- std::atomic_bool m_AccessLogEnabled;
+ std::atomic_bool m_WriteLogEnabled{false};
+ std::atomic_bool m_AccessLogEnabled{false};
friend class CacheStoreReferenceChecker;
};
diff --git a/src/zenstore/include/zenstore/caslog.h b/src/zenstore/include/zenstore/caslog.h
index f3dd32fb1..7967d9dae 100644
--- a/src/zenstore/include/zenstore/caslog.h
+++ b/src/zenstore/include/zenstore/caslog.h
@@ -20,8 +20,8 @@ public:
kTruncate
};
- static bool IsValid(std::filesystem::path FileName, size_t RecordSize);
- void Open(std::filesystem::path FileName, size_t RecordSize, Mode Mode);
+ static bool IsValid(const std::filesystem::path& FileName, size_t RecordSize);
+ void Open(const std::filesystem::path& FileName, size_t RecordSize, Mode Mode);
void Append(const void* DataPointer, uint64_t DataSize);
void Replay(std::function<void(const void*)>&& Handler, uint64_t SkipEntryCount);
void Flush();
@@ -48,7 +48,7 @@ private:
static_assert(sizeof(FileHeader) == 64);
private:
- void Open(std::filesystem::path FileName, size_t RecordSize, BasicFile::Mode Mode);
+ void Open(const std::filesystem::path& FileName, size_t RecordSize, BasicFile::Mode Mode);
BasicFile m_File;
FileHeader m_Header;
@@ -60,8 +60,8 @@ template<typename T>
class TCasLogFile : public CasLogFile
{
public:
- static bool IsValid(std::filesystem::path FileName) { return CasLogFile::IsValid(FileName, sizeof(T)); }
- void Open(std::filesystem::path FileName, Mode Mode) { CasLogFile::Open(FileName, sizeof(T), Mode); }
+ static bool IsValid(const std::filesystem::path& FileName) { return CasLogFile::IsValid(FileName, sizeof(T)); }
+ void Open(const std::filesystem::path& FileName, Mode Mode) { CasLogFile::Open(FileName, sizeof(T), Mode); }
// This should be called before the Replay() is called to do some basic sanity checking
bool Initialize() { return true; }
diff --git a/src/zenstore/include/zenstore/gc.h b/src/zenstore/include/zenstore/gc.h
index 734d2e5a7..67cf852f9 100644
--- a/src/zenstore/include/zenstore/gc.h
+++ b/src/zenstore/include/zenstore/gc.h
@@ -238,7 +238,7 @@ bool FilterReferences(GcCtx& Ctx, std::string_view Context, std::vector<IoHa
/**
* @brief An interface to implement a lock for Stop The World (from writing new data)
*
- * This interface is registered/unregistered to GcManager vua AddGcReferenceLocker() and RemoveGcReferenceLockerr()
+ * This interface is registered/unregistered to GcManager via AddGcReferenceLocker() and RemoveGcReferenceLocker()
*/
class GcReferenceLocker
{
@@ -443,8 +443,8 @@ struct GcSchedulerState
uint64_t DiskFree = 0;
GcClock::TimePoint LastFullGcTime{};
GcClock::TimePoint LastLightweightGcTime{};
- std::chrono::seconds RemainingTimeUntilLightweightGc;
- std::chrono::seconds RemainingTimeUntilFullGc;
+ std::chrono::seconds RemainingTimeUntilLightweightGc{};
+ std::chrono::seconds RemainingTimeUntilFullGc{};
uint64_t RemainingSpaceUntilFullGC = 0;
std::chrono::milliseconds LastFullGcDuration{};
@@ -562,7 +562,7 @@ private:
GcClock::TimePoint m_LastGcExpireTime{};
IoHash m_LastFullAttachmentRangeMin = IoHash::Zero;
IoHash m_LastFullAttachmentRangeMax = IoHash::Max;
- uint8_t m_AttachmentPassIndex;
+ uint8_t m_AttachmentPassIndex = 0;
std::chrono::milliseconds m_LastFullGcDuration{};
GcStorageSize m_LastFullGCDiff;
diff --git a/src/zenstore/include/zenstore/projectstore.h b/src/zenstore/include/zenstore/projectstore.h
index 33ef996db..6f49cd024 100644
--- a/src/zenstore/include/zenstore/projectstore.h
+++ b/src/zenstore/include/zenstore/projectstore.h
@@ -67,8 +67,8 @@ public:
struct OplogEntryAddress
{
- uint32_t Offset; // note: Multiple of m_OpsAlign!
- uint32_t Size;
+ uint32_t Offset = 0; // note: Multiple of m_OpsAlign!
+ uint32_t Size = 0;
};
struct OplogEntry
@@ -80,11 +80,7 @@ public:
uint32_t Reserved;
inline bool IsTombstone() const { return OpCoreAddress.Offset == 0 && OpCoreAddress.Size == 0 && OpLsn.Number; }
- inline void MakeTombstone()
- {
- OpLsn = {};
- OpCoreAddress.Offset = OpCoreAddress.Size = OpCoreHash = Reserved = 0;
- }
+ inline void MakeTombstone() { OpCoreAddress.Offset = OpCoreAddress.Size = OpCoreHash = Reserved = 0; }
};
static_assert(IsPow2(sizeof(OplogEntry)));
diff --git a/src/zenstore/projectstore.cpp b/src/zenstore/projectstore.cpp
index 1ab2b317a..03086b473 100644
--- a/src/zenstore/projectstore.cpp
+++ b/src/zenstore/projectstore.cpp
@@ -1488,7 +1488,7 @@ ProjectStore::Oplog::Read()
else
{
std::vector<OplogEntry> OpLogEntries;
- uint64_t InvalidEntries;
+ uint64_t InvalidEntries = 0;
m_Storage->ReadOplogEntriesFromLog(OpLogEntries, InvalidEntries, m_LogFlushPosition);
for (const OplogEntry& OpEntry : OpLogEntries)
{
@@ -1750,8 +1750,8 @@ ProjectStore::Oplog::Validate(const std::filesystem::path& ProjectRootDir,
}
};
- std::atomic<bool> AbortFlag;
- std::atomic<bool> PauseFlag;
+ std::atomic<bool> AbortFlag{false};
+ std::atomic<bool> PauseFlag{false};
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog);
try
{
@@ -2373,7 +2373,7 @@ ProjectStore::Oplog::IterateChunks(const std::filesystem::path& P
else if (auto MetaIt = m_MetaMap.find(ChunkId); MetaIt != m_MetaMap.end())
{
CidChunkIndexes.push_back(ChunkIndex);
- CidChunkHashes.push_back(ChunkIt->second);
+ CidChunkHashes.push_back(MetaIt->second);
}
else if (auto FileIt = m_FileMap.find(ChunkId); FileIt != m_FileMap.end())
{
@@ -2384,8 +2384,8 @@ ProjectStore::Oplog::IterateChunks(const std::filesystem::path& P
}
if (OptionalWorkerPool)
{
- std::atomic<bool> AbortFlag;
- std::atomic<bool> PauseFlag;
+ std::atomic<bool> AbortFlag{false};
+ std::atomic<bool> PauseFlag{false};
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog);
try
{
@@ -3817,7 +3817,7 @@ ProjectStore::Project::OpenOplog(std::string_view OplogId, bool AllowCompact, bo
std::filesystem::path DeletePath;
if (!RemoveOplog(OplogId, DeletePath))
{
- ZEN_WARN("Failed to clean up deleted oplog {}/{}", Identifier, OplogId, OplogBasePath);
+ ZEN_WARN("Failed to clean up deleted oplog {}/{} at '{}'", Identifier, OplogId, OplogBasePath);
}
ReOpen = true;
@@ -4053,8 +4053,8 @@ ProjectStore::Project::Scrub(ScrubContext& Ctx)
RwLock::SharedLockScope _(m_ProjectLock);
- std::atomic<bool> Abort;
- std::atomic<bool> Pause;
+ std::atomic<bool> Abort{false};
+ std::atomic<bool> Pause{false};
ParallelWork Work(Abort, Pause, WorkerThreadPool::EMode::DisableBacklog);
try
@@ -4360,7 +4360,7 @@ ProjectStore::ProjectStore(CidStore& Store, std::filesystem::path BasePath, GcMa
, m_DiskWriteBlocker(Gc.GetDiskWriteBlocker())
{
ZEN_INFO("initializing project store at '{}'", m_ProjectBasePath);
- // m_Log.set_level(spdlog::level::debug);
+ // m_Log.SetLogLevel(zen::logging::Debug);
m_Gc.AddGcStorage(this);
m_Gc.AddGcReferencer(*this);
m_Gc.AddGcReferenceLocker(*this);
@@ -4433,8 +4433,8 @@ ProjectStore::Flush()
}
WorkerThreadPool& WorkerPool = GetSmallWorkerPool(EWorkloadType::Burst);
- std::atomic<bool> AbortFlag;
- std::atomic<bool> PauseFlag;
+ std::atomic<bool> AbortFlag{false};
+ std::atomic<bool> PauseFlag{false};
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog);
try
{
@@ -4712,6 +4712,13 @@ ProjectStore::GetProjectsList()
Response << "ProjectRootDir"sv << PathToUtf8(Prj.ProjectRootDir);
Response << "EngineRootDir"sv << PathToUtf8(Prj.EngineRootDir);
Response << "ProjectFilePath"sv << PathToUtf8(Prj.ProjectFilePath);
+
+ const auto AccessTime = Prj.LastOplogAccessTime(""sv);
+ if (AccessTime != GcClock::TimePoint::min())
+ {
+ Response << "LastAccessTime"sv << gsl::narrow<uint64_t>(AccessTime.time_since_epoch().count());
+ }
+
Response.EndObject();
});
Response.EndArray();
@@ -4974,7 +4981,7 @@ ProjectStore::GetProjectChunkInfos(LoggerRef InLog, Project& Project, Oplog& Opl
}
if (WantsRawSizeField)
{
- ZEN_ASSERT_SLOW(Sizes[Index] == (uint64_t)-1);
+ ZEN_ASSERT_SLOW(RawSizes[Index] == (uint64_t)-1);
RawSizes[Index] = Payload.GetSize();
}
}
@@ -5762,7 +5769,7 @@ public:
}
}
- for (auto ProjectIt : m_ProjectStore.m_Projects)
+ for (const auto& ProjectIt : m_ProjectStore.m_Projects)
{
Ref<ProjectStore::Project> Project = ProjectIt.second;
std::vector<std::string> OplogsToCompact = Project->GetOplogsToCompact();
@@ -6802,6 +6809,8 @@ namespace testutils {
} // namespace testutils
+TEST_SUITE_BEGIN("store.projectstore");
+
TEST_CASE("project.opkeys")
{
using namespace std::literals;
@@ -8473,6 +8482,8 @@ TEST_CASE("project.store.iterateoplog")
}
}
+TEST_SUITE_END();
+
#endif
void
diff --git a/src/zenstore/workspaces.cpp b/src/zenstore/workspaces.cpp
index f0f975af4..ad21bbc68 100644
--- a/src/zenstore/workspaces.cpp
+++ b/src/zenstore/workspaces.cpp
@@ -383,7 +383,7 @@ Workspace::GetShares() const
{
std::vector<Ref<WorkspaceShare>> Shares;
Shares.reserve(m_Shares.size());
- for (auto It : m_Shares)
+ for (const auto& It : m_Shares)
{
Shares.push_back(It.second);
}
@@ -435,7 +435,7 @@ Workspaces::RefreshWorkspaceShares(const Oid& WorkspaceId)
Workspace = FindWorkspace(Lock, WorkspaceId);
if (Workspace)
{
- for (auto Share : Workspace->GetShares())
+ for (const auto& Share : Workspace->GetShares())
{
DeletedShares.insert(Share->GetConfig().Id);
}
@@ -482,6 +482,12 @@ Workspaces::RefreshWorkspaceShares(const Oid& WorkspaceId)
m_ShareAliases.erase(Share->GetConfig().Alias);
}
Workspace->SetShare(Configuration.Id, std::move(NewShare));
+ if (!Configuration.Alias.empty())
+ {
+ m_ShareAliases.insert_or_assign(
+ Configuration.Alias,
+ ShareAlias{.WorkspaceId = WorkspaceId, .ShareId = Configuration.Id});
+ }
}
}
else
@@ -602,7 +608,7 @@ Workspaces::GetWorkspaceShareChunks(const Oid& WorkspaceId,
{
RequestedOffset = Size;
}
- if ((RequestedOffset + RequestedSize) > Size)
+ if (RequestedSize > Size - RequestedOffset)
{
RequestedSize = Size - RequestedOffset;
}
@@ -649,7 +655,7 @@ Workspaces::GetWorkspaces() const
{
std::vector<Oid> Workspaces;
RwLock::SharedLockScope Lock(m_Lock);
- for (auto It : m_Workspaces)
+ for (const auto& It : m_Workspaces)
{
Workspaces.push_back(It.first);
}
@@ -679,7 +685,7 @@ Workspaces::GetWorkspaceShares(const Oid& WorkspaceId) const
if (Workspace)
{
std::vector<Oid> Shares;
- for (auto Share : Workspace->GetShares())
+ for (const auto& Share : Workspace->GetShares())
{
Shares.push_back(Share->GetConfig().Id);
}
@@ -1356,6 +1362,8 @@ namespace {
} // namespace
+TEST_SUITE_BEGIN("store.workspaces");
+
TEST_CASE("workspaces.scanfolder")
{
using namespace std::literals;
@@ -1559,6 +1567,8 @@ TEST_CASE("workspace.share.alias")
CHECK(!WS.GetShareAlias("my_share").has_value());
}
+TEST_SUITE_END();
+
#endif
void
diff --git a/src/zentelemetry-test/zentelemetry-test.cpp b/src/zentelemetry-test/zentelemetry-test.cpp
index 83fd549db..5a2ac74de 100644
--- a/src/zentelemetry-test/zentelemetry-test.cpp
+++ b/src/zentelemetry-test/zentelemetry-test.cpp
@@ -1,45 +1,15 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zencore/filesystem.h>
-#include <zencore/logging.h>
-#include <zencore/trace.h>
+#include <zencore/testing.h>
#include <zentelemetry/zentelemetry.h>
#include <zencore/memory/newdelete.h>
-#if ZEN_WITH_TESTS
-# define ZEN_TEST_WITH_RUNNER 1
-# include <zencore/testing.h>
-# include <zencore/process.h>
-#endif
-
int
main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
{
-#if ZEN_PLATFORM_WINDOWS
- setlocale(LC_ALL, "en_us.UTF8");
-#endif // ZEN_PLATFORM_WINDOWS
-
#if ZEN_WITH_TESTS
- zen::zentelemetry_forcelinktests();
-
-# if ZEN_PLATFORM_LINUX
- zen::IgnoreChildSignals();
-# endif
-
-# if ZEN_WITH_TRACE
- zen::TraceInit("zenstore-test");
- zen::TraceOptions TraceCommandlineOptions;
- if (GetTraceOptionsFromCommandline(TraceCommandlineOptions))
- {
- TraceConfigure(TraceCommandlineOptions);
- }
-# endif // ZEN_WITH_TRACE
-
- zen::logging::InitializeLogging();
- zen::MaximizeOpenFileCount();
-
- return ZEN_RUN_TESTS(argc, argv);
+ return zen::testing::RunTestMain(argc, argv, "zentelemetry-test", zen::zentelemetry_forcelinktests);
#else
return 0;
#endif
diff --git a/src/zentelemetry/include/zentelemetry/otlpencoder.h b/src/zentelemetry/include/zentelemetry/otlpencoder.h
index ed6665781..f280aa9ec 100644
--- a/src/zentelemetry/include/zentelemetry/otlpencoder.h
+++ b/src/zentelemetry/include/zentelemetry/otlpencoder.h
@@ -13,9 +13,9 @@
# include <protozero/pbf_builder.hpp>
# include <protozero/types.hpp>
-namespace spdlog { namespace details {
- struct log_msg;
-}} // namespace spdlog::details
+namespace zen::logging {
+struct LogMessage;
+} // namespace zen::logging
namespace zen::otel {
enum class Resource : protozero::pbf_tag_type;
@@ -46,7 +46,7 @@ public:
void AddResourceAttribute(const std::string_view& Key, const std::string_view& Value);
void AddResourceAttribute(const std::string_view& Key, int64_t Value);
- std::string FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const;
+ std::string FormatOtelProtobuf(const logging::LogMessage& Msg) const;
std::string FormatOtelMetrics() const;
std::string FormatOtelTrace(zen::otel::TraceId Trace, std::span<const zen::otel::Span*> Spans) const;
diff --git a/src/zentelemetry/include/zentelemetry/otlptrace.h b/src/zentelemetry/include/zentelemetry/otlptrace.h
index 49dd90358..95718af55 100644
--- a/src/zentelemetry/include/zentelemetry/otlptrace.h
+++ b/src/zentelemetry/include/zentelemetry/otlptrace.h
@@ -317,6 +317,7 @@ public:
ExtendableStringBuilder<128> NameBuilder;
NamingFunction(NameBuilder);
+ Initialize(NameBuilder);
}
/** Construct a new span with a naming function AND initializer function
@@ -350,7 +351,13 @@ public:
// Execute a function with the span pointer if valid. This can
// be used to add attributes or events to the span after creation
- inline void WithSpan(auto Func) const { Func(*m_Span); }
+ inline void WithSpan(auto Func) const
+ {
+ if (m_Span)
+ {
+ Func(*m_Span);
+ }
+ }
private:
void Initialize(std::string_view Name);
diff --git a/src/zentelemetry/include/zentelemetry/stats.h b/src/zentelemetry/include/zentelemetry/stats.h
index 3e67bac1c..260b0fcfb 100644
--- a/src/zentelemetry/include/zentelemetry/stats.h
+++ b/src/zentelemetry/include/zentelemetry/stats.h
@@ -16,11 +16,17 @@ class CbObjectWriter;
namespace zen::metrics {
+/** A single atomic value that can be set and read at any time.
+ *
+ * Useful for point-in-time readings such as queue depth, active connection count,
+ * or any value where only the current state matters rather than history.
+ */
template<typename T>
class Gauge
{
public:
Gauge() : m_Value{0} {}
+ explicit Gauge(T InitialValue) : m_Value{InitialValue} {}
T Value() const { return m_Value; }
void SetValue(T Value) { m_Value = Value; }
@@ -29,12 +35,12 @@ private:
std::atomic<T> m_Value;
};
-/** Stats counter
+/** Monotonically increasing (or decreasing) counter.
*
- * A counter is modified by adding or subtracting a value from a current value.
- * This would typically be used to track number of requests in flight, number
- * of active jobs etc
+ * Suitable for tracking quantities that go up and down over time, such as
+ * requests in flight or active jobs. All operations are lock-free via atomics.
*
+ * Unlike a Meter, a Counter does not track rates — it only records a running total.
*/
class Counter
{
@@ -50,34 +56,56 @@ private:
std::atomic<uint64_t> m_count{0};
};
-/** Exponential Weighted Moving Average
-
- This is very raw, to use as little state as possible. If we
- want to use this more broadly in user code we should perhaps
- add a more user-friendly wrapper
+/** Low-level exponential weighted moving average.
+ *
+ * Tracks a smoothed rate using the standard EWMA recurrence:
+ *
+ * rate = rate + alpha * (instantRate - rate)
+ *
+ * where instantRate = Count / Interval. The alpha value controls how quickly
+ * the average responds to changes — higher alpha means more weight on recent
+ * samples. Typical alphas are derived from a decay half-life (e.g. 1, 5, 15
+ * minutes) and a fixed tick interval.
+ *
+ * This class is intentionally minimal to keep per-instance state to a single
+ * atomic double. See Meter for a more convenient wrapper.
*/
-
class RawEWMA
{
public:
- /// <summary>
- /// Update EWMA with new measure
- /// </summary>
- /// <param name="Alpha">Smoothing factor (between 0 and 1)</param>
- /// <param name="Interval">Elapsed time since last</param>
- /// <param name="Count">Value</param>
- /// <param name="IsInitialUpdate">Whether this is the first update or not</param>
- void Tick(double Alpha, uint64_t Interval, uint64_t Count, bool IsInitialUpdate);
+ /** Update the EWMA with a new observation.
+ *
+ * @param Alpha Smoothing factor in (0, 1). Smaller values give a
+ * slower-moving average; larger values track recent
+ * changes more aggressively.
+ * @param Interval Elapsed hi-freq timer ticks since the last Tick call.
+ * Used to compute the instantaneous rate as Count/Interval.
+ * @param Count Number of events observed during this interval.
+ * @param IsInitialUpdate True on the very first call: seeds the rate directly
+ * from the instantaneous rate rather than blending it in.
+ */
+ void Tick(double Alpha, uint64_t Interval, uint64_t Count, bool IsInitialUpdate);
+
+ /** Returns the current smoothed rate in events per second. */
double Rate() const;
private:
std::atomic<double> m_Rate = 0;
};
-/// <summary>
-/// Tracks rate of events over time (i.e requests/sec), using
-/// exponential moving averages
-/// </summary>
+/** Tracks the rate of events over time using exponential moving averages.
+ *
+ * Maintains three EWMA windows (1, 5, 15 minutes) in addition to a simple
+ * mean rate computed from the total count and elapsed wall time since
+ * construction. This mirrors the load-average conventions familiar from Unix.
+ *
+ * Rate updates are batched: Mark() accumulates a pending count and the EWMA
+ * is only advanced every ~5 seconds (controlled by kTickIntervalInSeconds),
+ * keeping contention low even under heavy call rates. Rates are returned in
+ * events per second.
+ *
+ * All operations are thread-safe via lock-free atomics.
+ */
class Meter
{
public:
@@ -85,18 +113,18 @@ public:
~Meter();
inline uint64_t Count() const { return m_TotalCount; }
- double Rate1(); // One-minute rate
- double Rate5(); // Five-minute rate
- double Rate15(); // Fifteen-minute rate
- double MeanRate() const; // Mean rate since instantiation of this meter
+ double Rate1(); // One-minute EWMA rate (events/sec)
+ double Rate5(); // Five-minute EWMA rate (events/sec)
+ double Rate15(); // Fifteen-minute EWMA rate (events/sec)
+ double MeanRate() const; // Mean rate since instantiation (events/sec)
void Mark(uint64_t Count = 1); // Register one or more events
private:
std::atomic<uint64_t> m_TotalCount{0}; // Accumulator counting number of marks since beginning
- std::atomic<uint64_t> m_PendingCount{0}; // Pending EWMA update accumulator
- std::atomic<uint64_t> m_StartTick{0}; // Time this was instantiated (for mean)
- std::atomic<uint64_t> m_LastTick{0}; // Timestamp of last EWMA tick
- std::atomic<int64_t> m_Remainder{0}; // Tracks the "modulo" of tick time
+ std::atomic<uint64_t> m_PendingCount{0}; // Pending EWMA update accumulator; drained on each tick
+ std::atomic<uint64_t> m_StartTick{0}; // Hi-freq timer value at construction (for MeanRate)
+ std::atomic<uint64_t> m_LastTick{0}; // Hi-freq timer value of the last EWMA tick
+ std::atomic<int64_t> m_Remainder{0}; // Accumulated ticks not yet consumed by EWMA updates
bool m_IsFirstTick = true;
RawEWMA m_RateM1;
RawEWMA m_RateM5;
@@ -106,7 +134,14 @@ private:
void Tick();
};
-/** Moment-in-time snapshot of a distribution
+/** Immutable sorted snapshot of a reservoir sample.
+ *
+ * Constructed from a vector of sampled values which are sorted on construction.
+ * Percentiles are computed on demand via linear interpolation between adjacent
+ * sorted values, following the standard R-7 quantile method.
+ *
+ * Because this is a copy of the reservoir at a point in time, it can be held
+ * and queried without holding any locks on the source UniformSample.
*/
class SampleSnapshot
{
@@ -128,12 +163,19 @@ private:
std::vector<double> m_Values;
};
-/** Randomly selects samples from a stream. Uses Vitter's
- Algorithm R to produce a statistically representative sample.
-
- http://www.cs.umd.edu/~samir/498/vitter.pdf - Random Sampling with a Reservoir
+/** Reservoir sampler for probabilistic distribution tracking.
+ *
+ * Maintains a fixed-size reservoir of samples drawn uniformly from the full
+ * history of values using Vitter's Algorithm R. This gives an unbiased
+ * statistical representation of the value distribution regardless of how many
+ * total values have been observed, at the cost of O(ReservoirSize) memory.
+ *
+ * A larger reservoir improves accuracy of tail percentiles (P99, P999) but
+ * increases memory and snapshot cost. The default of 1028 gives good accuracy
+ * for most telemetry uses.
+ *
+ * http://www.cs.umd.edu/~samir/498/vitter.pdf - Random Sampling with a Reservoir
*/
-
class UniformSample
{
public:
@@ -159,7 +201,14 @@ private:
std::vector<std::atomic<int64_t>> m_Values;
};
-/** Track (probabilistic) sample distribution along with min/max
+/** Tracks the statistical distribution of a stream of values.
+ *
+ * Records exact min, max, count and mean across all values ever seen, plus a
+ * reservoir sample (via UniformSample) used to compute percentiles. Percentiles
+ * are therefore probabilistic — they reflect the distribution of a representative
+ * sample rather than the full history.
+ *
+ * All operations are thread-safe via lock-free atomics.
*/
class Histogram
{
@@ -183,11 +232,28 @@ private:
std::atomic<int64_t> m_Count{0};
};
-/** Track timing and frequency of some operation
-
- Example usage would be to track frequency and duration of network
- requests, or function calls.
-
+/** Combines a Histogram and a Meter to track both the distribution and rate
+ * of a recurring operation.
+ *
+ * Duration values are stored in hi-freq timer ticks. Use GetHifreqTimerToSeconds()
+ * when converting for display.
+ *
+ * Typical usage via the RAII Scope helper:
+ *
+ * OperationTiming MyTiming;
+ *
+ * {
+ * OperationTiming::Scope Scope(MyTiming);
+ * DoWork();
+ * // Scope destructor calls Stop() automatically
+ * }
+ *
+ * // Or cancel if the operation should not be counted:
+ * {
+ * OperationTiming::Scope Scope(MyTiming);
+ * if (CacheHit) { Scope.Cancel(); return; }
+ * DoExpensiveWork();
+ * }
*/
class OperationTiming
{
@@ -207,13 +273,19 @@ public:
double Rate15() { return m_Meter.Rate15(); }
double MeanRate() const { return m_Meter.MeanRate(); }
+ /** RAII helper that records duration from construction to Stop() or destruction.
+ *
+ * Call Cancel() to discard the measurement (e.g. for cache hits that should
+ * not skew latency statistics). After Stop() or Cancel() the destructor is a
+ * no-op.
+ */
struct Scope
{
Scope(OperationTiming& Outer);
~Scope();
- void Stop();
- void Cancel();
+ void Stop(); // Record elapsed time and mark the meter
+ void Cancel(); // Discard this measurement; destructor becomes a no-op
private:
OperationTiming& m_Outer;
@@ -225,6 +297,7 @@ private:
Histogram m_Histogram;
};
+/** Immutable snapshot of a Meter's state at a point in time. */
struct MeterSnapshot
{
uint64_t Count;
@@ -234,6 +307,12 @@ struct MeterSnapshot
double Rate15;
};
+/** Immutable snapshot of a Histogram's state at a point in time.
+ *
+ * Count and all statistical values have been scaled by the ConversionFactor
+ * supplied when the snapshot was taken (e.g. GetHifreqTimerToSeconds() to
+ * convert timer ticks to seconds).
+ */
struct HistogramSnapshot
{
double Count;
@@ -246,24 +325,29 @@ struct HistogramSnapshot
double P999;
};
+/** Combined snapshot of a Meter and Histogram pair. */
struct StatsSnapshot
{
MeterSnapshot Meter;
HistogramSnapshot Histogram;
};
+/** Combined snapshot of request timing and byte transfer statistics. */
struct RequestStatsSnapshot
{
StatsSnapshot Requests;
StatsSnapshot Bytes;
};
-/** Metrics for network requests
-
- Aggregates tracking of duration, payload sizes into a single
- class
-
- */
+/** Tracks both the timing and payload size of network requests.
+ *
+ * Maintains two independent histogram+meter pairs: one for request duration
+ * (in hi-freq timer ticks) and one for transferred bytes. Both dimensions
+ * share the same request count — a single Update() call advances both.
+ *
+ * Duration accessors return values in hi-freq timer ticks. Multiply by
+ * GetHifreqTimerToSeconds() to convert to seconds.
+ */
class RequestStats
{
public:
@@ -275,9 +359,9 @@ public:
// Timing
- int64_t MaxDuration() const { return m_BytesHistogram.Max(); }
- int64_t MinDuration() const { return m_BytesHistogram.Min(); }
- double MeanDuration() const { return m_BytesHistogram.Mean(); }
+ int64_t MaxDuration() const { return m_RequestTimeHistogram.Max(); }
+ int64_t MinDuration() const { return m_RequestTimeHistogram.Min(); }
+ double MeanDuration() const { return m_RequestTimeHistogram.Mean(); }
SampleSnapshot DurationSnapshot() const { return m_RequestTimeHistogram.Snapshot(); }
double Rate1() { return m_RequestMeter.Rate1(); }
double Rate5() { return m_RequestMeter.Rate5(); }
@@ -295,14 +379,23 @@ public:
double ByteRate15() { return m_BytesMeter.Rate15(); }
double ByteMeanRate() const { return m_BytesMeter.MeanRate(); }
+ /** RAII helper that records duration and byte count from construction to Stop()
+ * or destruction.
+ *
+ * The byte count can be supplied at construction or updated at any point via
+ * SetBytes() before the scope ends — useful when the response size is not
+ * known until the operation completes.
+ *
+ * Call Cancel() to discard the measurement entirely.
+ */
struct Scope
{
Scope(RequestStats& Outer, int64_t Bytes);
~Scope();
void SetBytes(int64_t Bytes) { m_Bytes = Bytes; }
- void Stop();
- void Cancel();
+ void Stop(); // Record elapsed time and byte count
+ void Cancel(); // Discard this measurement; destructor becomes a no-op
private:
RequestStats& m_Outer;
diff --git a/src/zentelemetry/otlpencoder.cpp b/src/zentelemetry/otlpencoder.cpp
index 677545066..5477c5381 100644
--- a/src/zentelemetry/otlpencoder.cpp
+++ b/src/zentelemetry/otlpencoder.cpp
@@ -3,9 +3,9 @@
#include "zentelemetry/otlpencoder.h"
#include <zenbase/zenbase.h>
+#include <zencore/logging/logmsg.h>
#include <zentelemetry/otlptrace.h>
-#include <spdlog/sinks/sink.h>
#include <zencore/testing.h>
#include <protozero/buffer_string.hpp>
@@ -29,49 +29,49 @@ OtlpEncoder::~OtlpEncoder()
}
static int
-MapSeverity(const spdlog::level::level_enum Level)
+MapSeverity(const logging::LogLevel Level)
{
switch (Level)
{
- case spdlog::level::critical:
+ case logging::Critical:
return otel::SEVERITY_NUMBER_FATAL;
- case spdlog::level::err:
+ case logging::Err:
return otel::SEVERITY_NUMBER_ERROR;
- case spdlog::level::warn:
+ case logging::Warn:
return otel::SEVERITY_NUMBER_WARN;
- case spdlog::level::info:
+ case logging::Info:
return otel::SEVERITY_NUMBER_INFO;
- case spdlog::level::debug:
+ case logging::Debug:
return otel::SEVERITY_NUMBER_DEBUG;
default:
- case spdlog::level::trace:
+ case logging::Trace:
return otel::SEVERITY_NUMBER_TRACE;
}
}
static const char*
-MapSeverityText(const spdlog::level::level_enum Level)
+MapSeverityText(const logging::LogLevel Level)
{
switch (Level)
{
- case spdlog::level::critical:
+ case logging::Critical:
return "fatal";
- case spdlog::level::err:
+ case logging::Err:
return "error";
- case spdlog::level::warn:
+ case logging::Warn:
return "warn";
- case spdlog::level::info:
+ case logging::Info:
return "info";
- case spdlog::level::debug:
+ case logging::Debug:
return "debug";
default:
- case spdlog::level::trace:
+ case logging::Trace:
return "trace";
}
}
std::string
-OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const
+OtlpEncoder::FormatOtelProtobuf(const logging::LogMessage& Msg) const
{
std::string Data;
@@ -98,7 +98,7 @@ OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const
protozero::pbf_builder<otel::InstrumentationScope> IsBuilder{SlBuilder,
otel::ScopeLogs::required_InstrumentationScope_scope};
- IsBuilder.add_string(otel::InstrumentationScope::string_name, Msg.logger_name.data(), Msg.logger_name.size());
+ IsBuilder.add_string(otel::InstrumentationScope::string_name, Msg.GetLoggerName().data(), Msg.GetLoggerName().size());
}
// LogRecord log_records
@@ -106,13 +106,13 @@ OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const
protozero::pbf_builder<otel::LogRecord> LrBuilder{SlBuilder, otel::ScopeLogs::required_repeated_LogRecord_log_records};
LrBuilder.add_fixed64(otel::LogRecord::required_fixed64_time_unix_nano,
- std::chrono::duration_cast<std::chrono::nanoseconds>(Msg.time.time_since_epoch()).count());
+ std::chrono::duration_cast<std::chrono::nanoseconds>(Msg.GetTime().time_since_epoch()).count());
- const int Severity = MapSeverity(Msg.level);
+ const int Severity = MapSeverity(Msg.GetLevel());
LrBuilder.add_enum(otel::LogRecord::optional_SeverityNumber_severity_number, Severity);
- LrBuilder.add_string(otel::LogRecord::optional_string_severity_text, MapSeverityText(Msg.level));
+ LrBuilder.add_string(otel::LogRecord::optional_string_severity_text, MapSeverityText(Msg.GetLevel()));
otel::TraceId TraceId;
const otel::SpanId SpanId = otel::Span::GetCurrentSpanId(TraceId);
@@ -127,7 +127,7 @@ OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const
{
protozero::pbf_builder<otel::AnyValue> BodyBuilder{LrBuilder, otel::LogRecord::optional_anyvalue_body};
- BodyBuilder.add_string(otel::AnyValue::string_string_value, Msg.payload.data(), Msg.payload.size());
+ BodyBuilder.add_string(otel::AnyValue::string_string_value, Msg.GetPayload().data(), Msg.GetPayload().size());
}
// attributes
@@ -139,7 +139,7 @@ OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const
{
protozero::pbf_builder<otel::AnyValue> AvBuilder{KvBuilder, otel::KeyValue::AnyValue_value};
- AvBuilder.add_int64(otel::AnyValue::int64_int_value, Msg.thread_id);
+ AvBuilder.add_int64(otel::AnyValue::int64_int_value, Msg.GetThreadId());
}
}
}
diff --git a/src/zentelemetry/otlptrace.cpp b/src/zentelemetry/otlptrace.cpp
index 6a095cfeb..3888717d5 100644
--- a/src/zentelemetry/otlptrace.cpp
+++ b/src/zentelemetry/otlptrace.cpp
@@ -385,6 +385,8 @@ otlptrace_forcelink()
# if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("telemetry.otlptrace");
+
TEST_CASE("otlp.trace")
{
// Enable OTLP tracing for the duration of this test
@@ -409,6 +411,8 @@ TEST_CASE("otlp.trace")
}
}
+TEST_SUITE_END();
+
# endif
} // namespace zen::otel
diff --git a/src/zentelemetry/stats.cpp b/src/zentelemetry/stats.cpp
index c67fa3c66..a417bb52c 100644
--- a/src/zentelemetry/stats.cpp
+++ b/src/zentelemetry/stats.cpp
@@ -631,7 +631,7 @@ EmitSnapshot(const HistogramSnapshot& Snapshot, CbObjectWriter& Cbo)
{
Cbo << "t_count" << Snapshot.Count << "t_avg" << Snapshot.Avg;
Cbo << "t_min" << Snapshot.Min << "t_max" << Snapshot.Max;
- Cbo << "t_p75" << Snapshot.P75 << "t_p95" << Snapshot.P95 << "t_p99" << Snapshot.P999;
+ Cbo << "t_p75" << Snapshot.P75 << "t_p95" << Snapshot.P95 << "t_p99" << Snapshot.P99 << "t_p999" << Snapshot.P999;
}
void
@@ -660,6 +660,8 @@ EmitSnapshot(std::string_view Tag, const RequestStatsSnapshot& Snapshot, CbObjec
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("telemetry.stats");
+
TEST_CASE("Core.Stats.Histogram")
{
Histogram Histo{258};
@@ -819,6 +821,8 @@ TEST_CASE("Meter")
# endif
}
+TEST_SUITE_END();
+
namespace zen {
void
diff --git a/src/zentelemetry/xmake.lua b/src/zentelemetry/xmake.lua
index 7739c0a08..cd9a18ec4 100644
--- a/src/zentelemetry/xmake.lua
+++ b/src/zentelemetry/xmake.lua
@@ -6,5 +6,5 @@ target('zentelemetry')
add_headerfiles("**.h")
add_files("**.cpp")
add_includedirs("include", {public=true})
- add_deps("zencore", "protozero", "spdlog")
+ add_deps("zencore", "protozero")
add_deps("robin-map")
diff --git a/src/zentest-appstub/xmake.lua b/src/zentest-appstub/xmake.lua
index 97615e322..844ba82ef 100644
--- a/src/zentest-appstub/xmake.lua
+++ b/src/zentest-appstub/xmake.lua
@@ -5,6 +5,7 @@ target("zentest-appstub")
set_group("tests")
add_headerfiles("**.h")
add_files("*.cpp")
+ add_deps("zencore")
if is_os("linux") then
add_syslinks("pthread")
diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp
index 24cf21e97..509629739 100644
--- a/src/zentest-appstub/zentest-appstub.cpp
+++ b/src/zentest-appstub/zentest-appstub.cpp
@@ -1,33 +1,418 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compress.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/stream.h>
+
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+#endif
+
+#include <fmt/format.h>
+
#include <stdio.h>
+#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <cstring>
+#include <filesystem>
+#include <string>
+#include <system_error>
#include <thread>
-using namespace std::chrono_literals;
+using namespace std::literals;
+using namespace zen;
+
+#if !defined(_MSC_VER)
+# define _strnicmp strncasecmp // TEMPORARY WORKAROUND - should not be using this
+#endif
+
+// Some basic functions to implement some test "compute" functions
+
+std::string
+Rot13Function(std::string_view InputString)
+{
+ std::string OutputString{InputString};
+
+ std::transform(OutputString.begin(),
+ OutputString.end(),
+ OutputString.begin(),
+ [](std::string::value_type c) -> std::string::value_type {
+ if (c >= 'a' && c <= 'z')
+ {
+ return 'a' + (c - 'a' + 13) % 26;
+ }
+ else if (c >= 'A' && c <= 'Z')
+ {
+ return 'A' + (c - 'A' + 13) % 26;
+ }
+ else
+ {
+ return c;
+ }
+ });
+
+ return OutputString;
+}
+
+std::string
+ReverseFunction(std::string_view InputString)
+{
+ std::string OutputString{InputString};
+ std::reverse(OutputString.begin(), OutputString.end());
+ return OutputString;
+}
+
+std::string
+IdentityFunction(std::string_view InputString)
+{
+ return std::string{InputString};
+}
+
+std::string
+NullFunction(std::string_view)
+{
+ return {};
+}
+
+zen::CbObject
+DescribeFunctions()
+{
+ CbObjectWriter Versions;
+ Versions << "BuildSystemVersion" << Guid::FromString("17fe280d-ccd8-4be8-a9d1-89c944a70969"sv);
+
+ Versions.BeginArray("Functions"sv);
+ Versions.BeginObject();
+ Versions << "Name"sv
+ << "Null"sv;
+ Versions << "Version"sv << Guid::FromString("00000000-0000-0000-0000-000000000000"sv);
+ Versions.EndObject();
+ Versions.BeginObject();
+ Versions << "Name"sv
+ << "Identity"sv;
+ Versions << "Version"sv << Guid::FromString("11111111-1111-1111-1111-111111111111"sv);
+ Versions.EndObject();
+ Versions.BeginObject();
+ Versions << "Name"sv
+ << "Rot13"sv;
+ Versions << "Version"sv << Guid::FromString("13131313-1313-1313-1313-131313131313"sv);
+ Versions.EndObject();
+ Versions.BeginObject();
+ Versions << "Name"sv
+ << "Reverse"sv;
+ Versions << "Version"sv << Guid::FromString("31313131-3131-3131-3131-313131313131"sv);
+ Versions.EndObject();
+ Versions.BeginObject();
+ Versions << "Name"sv
+ << "Sleep"sv;
+ Versions << "Version"sv << Guid::FromString("88888888-8888-8888-8888-888888888888"sv);
+ Versions.EndObject();
+ Versions.EndArray();
+
+ return Versions.Save();
+}
+
+struct ContentResolver
+{
+ std::filesystem::path InputsRoot;
+
+ CompressedBuffer ResolveChunk(IoHash Hash, uint64_t ExpectedSize)
+ {
+ std::filesystem::path ChunkPath = InputsRoot / Hash.ToHexString();
+ IoBuffer ChunkBuffer = IoBufferBuilder::MakeFromFile(ChunkPath);
+
+ IoHash RawHash;
+ uint64_t RawSize = 0;
+ CompressedBuffer AsCompressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkBuffer), RawHash, RawSize);
+
+ if (RawSize != ExpectedSize)
+ {
+ throw std::runtime_error(
+ fmt::format("chunk size mismatch - expected {}, got {} for '{}'", ExpectedSize, ChunkBuffer.Size(), ChunkPath));
+ }
+ if (RawHash != Hash)
+ {
+ throw std::runtime_error(fmt::format("chunk hash mismatch - expected {}, got {} for '{}'", Hash, RawHash, ChunkPath));
+ }
+
+ return AsCompressed;
+ }
+};
+
+zen::CbPackage
+ExecuteFunction(CbObject Action, ContentResolver ChunkResolver)
+{
+ auto Apply = [&](auto Func) {
+ zen::CbPackage Result;
+ auto Source = Action["Inputs"sv].AsObjectView()["Source"sv].AsObjectView();
+
+ IoHash InputRawHash = Source["RawHash"sv].AsHash();
+ uint64_t InputRawSize = Source["RawSize"sv].AsUInt64();
+
+ zen::CompressedBuffer InputData = ChunkResolver.ResolveChunk(InputRawHash, InputRawSize);
+ SharedBuffer Input = InputData.Decompress();
+
+ std::string Output = Func(std::string_view(static_cast<const char*>(Input.GetData()), Input.GetSize()));
+ zen::CompressedBuffer OutputData =
+ zen::CompressedBuffer::Compress(SharedBuffer::MakeView(Output), OodleCompressor::Selkie, OodleCompressionLevel::HyperFast4);
+ IoHash OutputRawHash = OutputData.DecodeRawHash();
+
+ CbAttachment OutputAttachment(std::move(OutputData), OutputRawHash);
+
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("Values"sv);
+ Cbo.BeginObject();
+ Cbo << "Id" << Oid{1, 2, 3};
+ Cbo.AddAttachment("RawHash", OutputAttachment);
+ Cbo << "RawSize" << Output.size();
+ Cbo.EndObject();
+ Cbo.EndArray();
+
+ Result.SetObject(Cbo.Save());
+ Result.AddAttachment(std::move(OutputAttachment));
+ return Result;
+ };
+
+ std::string_view Function = Action["Function"sv].AsString();
+
+ if (Function == "Rot13"sv)
+ {
+ return Apply(Rot13Function);
+ }
+ else if (Function == "Reverse"sv)
+ {
+ return Apply(ReverseFunction);
+ }
+ else if (Function == "Identity"sv)
+ {
+ return Apply(IdentityFunction);
+ }
+ else if (Function == "Null"sv)
+ {
+ return Apply(NullFunction);
+ }
+ else if (Function == "Sleep"sv)
+ {
+ uint64_t SleepTimeMs = Action["Constants"sv].AsObjectView()["SleepTimeMs"sv].AsUInt64();
+ zen::Sleep(static_cast<int>(SleepTimeMs));
+ return Apply(IdentityFunction);
+ }
+ else
+ {
+ return {};
+ }
+}
+
+/* This implements a minimal application to help testing of process launch-related
+ functionality
+
+ It also mimics the DDC2 worker command line interface, so it may be used to
+ exercise compute infrastructure.
+ */
int
main(int argc, char* argv[])
{
int ExitCode = 0;
- for (int i = 0; i < argc; ++i)
+ try
{
- if (std::strncmp(argv[i], "-t=", 3) == 0)
+ std::filesystem::path BasePath = std::filesystem::current_path();
+ std::filesystem::path InputPath = std::filesystem::current_path() / "Inputs";
+ std::filesystem::path OutputPath = std::filesystem::current_path() / "Outputs";
+ std::filesystem::path VersionPath = std::filesystem::current_path() / "Versions";
+ std::vector<std::filesystem::path> ActionPaths;
+
+ /*
+ GetSwitchValues(TEXT("-B="), ActionPathPatterns);
+ GetSwitchValues(TEXT("-Build="), ActionPathPatterns);
+
+ GetSwitchValues(TEXT("-I="), InputDirectoryPaths);
+ GetSwitchValues(TEXT("-Input="), InputDirectoryPaths);
+
+ GetSwitchValues(TEXT("-O="), OutputDirectoryPaths);
+ GetSwitchValues(TEXT("-Output="), OutputDirectoryPaths);
+
+ GetSwitchValues(TEXT("-V="), VersionPaths);
+ GetSwitchValues(TEXT("-Version="), VersionPaths);
+ */
+
+ auto SplitArg = [](const char* Arg) -> std::string_view {
+ std::string_view ArgView{Arg};
+ if (auto SplitPos = ArgView.find_first_of('='); SplitPos != std::string_view::npos)
+ {
+ return ArgView.substr(SplitPos + 1);
+ }
+ else
+ {
+ return {};
+ }
+ };
+
+ auto ParseIntArg = [](std::string_view Arg) -> int {
+ int Rv = 0;
+ const auto Result = std::from_chars(Arg.data(), Arg.data() + Arg.size(), Rv);
+
+ if (Result.ec != std::errc{})
+ {
+ throw std::invalid_argument(fmt::format("bad argument (not an integer): {}", Arg).c_str());
+ }
+
+ return Rv;
+ };
+
+ for (int i = 1; i < argc; ++i)
{
- const int SleepTime = std::atoi(argv[i] + 3);
+ std::string_view Arg = argv[i];
+
+ if (Arg.compare(0, 1, "-"))
+ {
+ continue;
+ }
+
+ if (std::strncmp(argv[i], "-t=", 3) == 0)
+ {
+ const int SleepTime = std::atoi(argv[i] + 3);
+
+ printf("[zentest] sleeping for %ds...\n", SleepTime);
+
+ std::this_thread::sleep_for(SleepTime * 1s);
+ }
+ else if (std::strncmp(argv[i], "-f=", 3) == 0)
+ {
+ // Force a "failure" process exit code to return to the invoker
+
+ // This may throw for invalid arguments, which makes this useful for
+ // testing exception handling
+ std::string_view ErrorArg = SplitArg(argv[i]);
+ ExitCode = ParseIntArg(ErrorArg);
+ }
+ else if ((_strnicmp(argv[i], "-input=", 7) == 0) || (_strnicmp(argv[i], "-i=", 3) == 0))
+ {
+ /* mimic DDC2
+
+ GetSwitchValues(TEXT("-I="), InputDirectoryPaths);
+ GetSwitchValues(TEXT("-Input="), InputDirectoryPaths);
+ */
+
+ std::string_view InputArg = SplitArg(argv[i]);
+ InputPath = InputArg;
+ }
+ else if ((_strnicmp(argv[i], "-output=", 8) == 0) || (_strnicmp(argv[i], "-o=", 3) == 0))
+ {
+ /* mimic DDC2 handling of where files storing output chunk files are directed
+
+ GetSwitchValues(TEXT("-O="), OutputDirectoryPaths);
+ GetSwitchValues(TEXT("-Output="), OutputDirectoryPaths);
+ */
- printf("[zentest] sleeping for %ds...\n", SleepTime);
+ std::string_view OutputArg = SplitArg(argv[i]);
+ OutputPath = OutputArg;
+ }
+ else if ((_strnicmp(argv[i], "-version=", 8) == 0) || (_strnicmp(argv[i], "-v=", 3) == 0))
+ {
+ /* mimic DDC2
- std::this_thread::sleep_for(SleepTime * 1s);
+ GetSwitchValues(TEXT("-V="), VersionPaths);
+ GetSwitchValues(TEXT("-Version="), VersionPaths);
+ */
+
+ std::string_view VersionArg = SplitArg(argv[i]);
+ VersionPath = VersionArg;
+ }
+ else if ((_strnicmp(argv[i], "-build=", 7) == 0) || (_strnicmp(argv[i], "-b=", 3) == 0))
+ {
+ /* mimic DDC2
+
+ GetSwitchValues(TEXT("-B="), ActionPathPatterns);
+ GetSwitchValues(TEXT("-Build="), ActionPathPatterns);
+ */
+
+ std::string_view BuildActionArg = SplitArg(argv[i]);
+ std::filesystem::path ActionPath{BuildActionArg};
+ ActionPaths.push_back(ActionPath);
+
+ ExitCode = 0;
+ }
}
- else if (std::strncmp(argv[i], "-f=", 3) == 0)
+
+ // Emit version information
+
+ if (!VersionPath.empty())
{
- ExitCode = std::atoi(argv[i] + 3);
+ CbObjectWriter Version;
+
+ Version << "BuildSystemVersion" << Guid::FromString("17fe280d-ccd8-4be8-a9d1-89c944a70969"sv);
+
+ Version.BeginArray("Functions");
+
+ Version.BeginObject();
+ Version << "Name"
+ << "Rot13"
+ << "Version" << Guid::FromString("13131313-1313-1313-1313-131313131313"sv);
+ Version.EndObject();
+
+ Version.BeginObject();
+ Version << "Name"
+ << "Reverse"
+ << "Version" << Guid::FromString("98765432-1000-0000-0000-000000000000"sv);
+ Version.EndObject();
+
+ Version.BeginObject();
+ Version << "Name"
+ << "Identity"
+ << "Version" << Guid::FromString("11111111-1111-1111-1111-111111111111"sv);
+ Version.EndObject();
+
+ Version.BeginObject();
+ Version << "Name"
+ << "Null"
+ << "Version" << Guid::FromString("00000000-0000-0000-0000-000000000000"sv);
+ Version.EndObject();
+
+ Version.EndArray();
+ CbObject VersionObject = Version.Save();
+
+ BinaryWriter Writer;
+ zen::SaveCompactBinary(Writer, VersionObject);
+ zen::WriteFile(VersionPath, IoBufferBuilder::MakeFromMemory(Writer.GetView()));
+ }
+
+ // Evaluate actions
+
+ ContentResolver Resolver;
+ Resolver.InputsRoot = InputPath;
+
+ for (std::filesystem::path ActionPath : ActionPaths)
+ {
+ IoBuffer ActionDescBuffer = ReadFile(ActionPath).Flatten();
+ CbObject ActionDesc = LoadCompactBinaryObject(ActionDescBuffer);
+ CbPackage Result = ExecuteFunction(ActionDesc, Resolver);
+ CbObject ResultObject = Result.GetObject();
+
+ BinaryWriter Writer;
+ zen::SaveCompactBinary(Writer, ResultObject);
+ zen::WriteFile(ActionPath.replace_extension(".output"), IoBufferBuilder::MakeFromMemory(Writer.GetView()));
+
+ // Also marshal outputs
+
+ for (const auto& Attachment : Result.GetAttachments())
+ {
+ const CompositeBuffer& AttachmentBuffer = Attachment.AsCompressedBinary().GetCompressed();
+ zen::WriteFile(OutputPath / Attachment.GetHash().ToHexString(), AttachmentBuffer.Flatten().AsIoBuffer());
+ }
}
}
+ catch (std::exception& Ex)
+ {
+ printf("[zentest] exception caught in main: '%s'\n", Ex.what());
+
+ ExitCode = 99;
+ }
printf("[zentest] exiting with exit code: %d\n", ExitCode);
diff --git a/src/zenutil-test/zenutil-test.cpp b/src/zenutil-test/zenutil-test.cpp
index f5cfd5a72..e2b6ac9bd 100644
--- a/src/zenutil-test/zenutil-test.cpp
+++ b/src/zenutil-test/zenutil-test.cpp
@@ -1,45 +1,15 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zencore/filesystem.h>
-#include <zencore/logging.h>
-#include <zencore/trace.h>
+#include <zencore/testing.h>
#include <zenutil/zenutil.h>
#include <zencore/memory/newdelete.h>
-#if ZEN_WITH_TESTS
-# define ZEN_TEST_WITH_RUNNER 1
-# include <zencore/testing.h>
-# include <zencore/process.h>
-#endif
-
int
main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
{
-#if ZEN_PLATFORM_WINDOWS
- setlocale(LC_ALL, "en_us.UTF8");
-#endif // ZEN_PLATFORM_WINDOWS
-
#if ZEN_WITH_TESTS
- zen::zenutil_forcelinktests();
-
-# if ZEN_PLATFORM_LINUX
- zen::IgnoreChildSignals();
-# endif
-
-# if ZEN_WITH_TRACE
- zen::TraceInit("zenutil-test");
- zen::TraceOptions TraceCommandlineOptions;
- if (GetTraceOptionsFromCommandline(TraceCommandlineOptions))
- {
- TraceConfigure(TraceCommandlineOptions);
- }
-# endif // ZEN_WITH_TRACE
-
- zen::logging::InitializeLogging();
- zen::MaximizeOpenFileCount();
-
- return ZEN_RUN_TESTS(argc, argv);
+ return zen::testing::RunTestMain(argc, argv, "zenutil-test", zen::zenutil_forcelinktests);
#else
return 0;
#endif
diff --git a/src/zenutil/commandlineoptions.cpp b/src/zenutil/config/commandlineoptions.cpp
index d94564843..25f5522d8 100644
--- a/src/zenutil/commandlineoptions.cpp
+++ b/src/zenutil/config/commandlineoptions.cpp
@@ -1,7 +1,8 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zenutil/commandlineoptions.h>
+#include <zenutil/config/commandlineoptions.h>
+#include <zencore/filesystem.h>
#include <zencore/string.h>
#include <filesystem>
@@ -194,6 +195,8 @@ commandlineoptions_forcelink()
{
}
+TEST_SUITE_BEGIN("util.commandlineoptions");
+
TEST_CASE("CommandLine")
{
std::vector<std::string> v1 = ParseCommandLine("c:\\my\\exe.exe \"quoted arg\" \"one\",two,\"three\\\"");
@@ -235,5 +238,7 @@ TEST_CASE("CommandLine")
CHECK_EQ(v3Stripped[5], std::string("--build-part-name=win64"));
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zenutil/environmentoptions.cpp b/src/zenutil/config/environmentoptions.cpp
index ee40086c1..fb7f71706 100644
--- a/src/zenutil/environmentoptions.cpp
+++ b/src/zenutil/config/environmentoptions.cpp
@@ -1,6 +1,6 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zenutil/environmentoptions.h>
+#include <zenutil/config/environmentoptions.h>
#include <zencore/filesystem.h>
diff --git a/src/zenutil/config/loggingconfig.cpp b/src/zenutil/config/loggingconfig.cpp
new file mode 100644
index 000000000..5092c60aa
--- /dev/null
+++ b/src/zenutil/config/loggingconfig.cpp
@@ -0,0 +1,77 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenutil/config/loggingconfig.h"
+
+#include <zenbase/zenbase.h>
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cxxopts.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
void
ZenLoggingCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig)
{
    // Registers the "logging" option group on the given cxxopts parser.
    // Most options bind directly into fields of LoggingConfig; --abslog is
    // captured into m_AbsLogFile so that ApplyOptions() can later normalize
    // it into an absolute path.
    // clang-format off
    options.add_options("logging")
        ("abslog", "Path to log file", cxxopts::value<std::string>(m_AbsLogFile))
        ("log-id", "Specify id for adding context to log output", cxxopts::value<std::string>(LoggingConfig.LogId))
        ("quiet", "Configure console logger output to level WARN", cxxopts::value<bool>(LoggingConfig.QuietConsole)->default_value("false"))
        ("noconsole", "Disable console logging", cxxopts::value<bool>(LoggingConfig.NoConsoleOutput)->default_value("false"))
        ("log-trace", "Change selected loggers to level TRACE", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Trace]))
        ("log-debug", "Change selected loggers to level DEBUG", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Debug]))
        ("log-info", "Change selected loggers to level INFO", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Info]))
        ("log-warn", "Change selected loggers to level WARN", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Warn]))
        ("log-error", "Change selected loggers to level ERROR", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Err]))
        ("log-critical", "Change selected loggers to level CRITICAL", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Critical]))
        ("log-off", "Change selected loggers to level OFF", cxxopts::value<std::string>(LoggingConfig.Loggers[logging::Off]))
        ("otlp-endpoint", "OpenTelemetry endpoint URI (e.g http://localhost:4318)", cxxopts::value<std::string>(LoggingConfig.OtelEndpointUri))
        ;
    // clang-format on
}
+
void
ZenLoggingCmdLineOptions::ApplyOptions(ZenLoggingConfig& LoggingConfig)
{
    // Post-parse step: resolve the raw --abslog value captured in
    // AddCliOptions() into a sanitized absolute path on the config object.
    LoggingConfig.AbsLogFile = MakeSafeAbsolutePath(m_AbsLogFile);
}
+
+void
+ApplyLoggingOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig)
+{
+ ZEN_UNUSED(options);
+
+ if (LoggingConfig.QuietConsole)
+ {
+ bool HasExplicitConsoleLevel = false;
+ for (int i = 0; i < logging::LogLevelCount; ++i)
+ {
+ if (LoggingConfig.Loggers[i].find("console") != std::string::npos)
+ {
+ HasExplicitConsoleLevel = true;
+ break;
+ }
+ }
+
+ if (!HasExplicitConsoleLevel)
+ {
+ std::string& WarnLoggers = LoggingConfig.Loggers[logging::Warn];
+ if (!WarnLoggers.empty())
+ {
+ WarnLoggers += ",";
+ }
+ WarnLoggers += "console";
+ }
+ }
+
+ for (int i = 0; i < logging::LogLevelCount; ++i)
+ {
+ logging::ConfigureLogLevels(logging::LogLevel(i), LoggingConfig.Loggers[i]);
+ }
+ logging::RefreshLogLevels();
+}
+
+} // namespace zen
diff --git a/src/zenutil/consoletui.cpp b/src/zenutil/consoletui.cpp
new file mode 100644
index 000000000..4410d463d
--- /dev/null
+++ b/src/zenutil/consoletui.cpp
@@ -0,0 +1,483 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/consoletui.h>
+
+#include <zencore/zencore.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#else
+# include <poll.h>
+# include <sys/ioctl.h>
+# include <termios.h>
+# include <unistd.h>
+#endif
+
+#include <cstdio>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+// Platform-specific terminal helpers
+
+#if ZEN_PLATFORM_WINDOWS
+
static bool
CheckIsInteractiveTerminal()
{
    // Both stdin and stdout must be real console handles (GetConsoleMode
    // succeeds on each) for interactive TUI features to be usable.
    DWORD dwMode = 0;
    return GetConsoleMode(GetStdHandle(STD_INPUT_HANDLE), &dwMode) && GetConsoleMode(GetStdHandle(STD_OUTPUT_HANDLE), &dwMode);
}

static void
EnableVirtualTerminal()
{
    // Opt stdout into ANSI/VT escape-sequence processing; silently does
    // nothing when stdout is not a console (GetConsoleMode fails).
    HANDLE hStdOut = GetStdHandle(STD_OUTPUT_HANDLE);
    DWORD dwMode = 0;
    if (GetConsoleMode(hStdOut, &dwMode))
    {
        SetConsoleMode(hStdOut, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING);
    }
}

// RAII guard: sets the console output code page for the lifetime of the object and
// restores the original on destruction. Required for UTF-8 glyphs to render correctly
// via printf/fflush since the default console code page is not UTF-8.
class ConsoleCodePageGuard
{
public:
    explicit ConsoleCodePageGuard(UINT NewCP) : m_OldCP(GetConsoleOutputCP()) { SetConsoleOutputCP(NewCP); }
    ~ConsoleCodePageGuard() { SetConsoleOutputCP(m_OldCP); }

private:
    UINT m_OldCP; // code page active before construction, restored in dtor
};

// Keys the TUI input loop distinguishes; everything else maps to Unknown.
enum class ConsoleKey
{
    Unknown,
    ArrowUp,
    ArrowDown,
    Enter,
    Escape,
};
+
static ConsoleKey
ReadKey()
{
    // Blocks until a key-down event for one of the keys the TUI cares about
    // (arrows / Enter / Esc) arrives on the console input queue. All other
    // input records (key-ups, unrelated keys, mouse/resize events) are
    // discarded. A failed read is reported as Escape so callers cancel.
    HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE);
    INPUT_RECORD Record{};
    DWORD dwRead = 0;
    while (true)
    {
        if (!ReadConsoleInputA(hStdin, &Record, 1, &dwRead))
        {
            return ConsoleKey::Escape; // treat read error as cancel
        }
        if (Record.EventType == KEY_EVENT && Record.Event.KeyEvent.bKeyDown)
        {
            switch (Record.Event.KeyEvent.wVirtualKeyCode)
            {
                case VK_UP:
                    return ConsoleKey::ArrowUp;
                case VK_DOWN:
                    return ConsoleKey::ArrowDown;
                case VK_RETURN:
                    return ConsoleKey::Enter;
                case VK_ESCAPE:
                    return ConsoleKey::Escape;
                default:
                    break;
            }
        }
    }
}
+
+#else // POSIX
+
+static bool
+CheckIsInteractiveTerminal()
+{
+ return isatty(STDIN_FILENO) && isatty(STDOUT_FILENO);
+}
+
+static void
+EnableVirtualTerminal()
+{
+ // ANSI escape codes are native on POSIX terminals; nothing to do
+}
+
// RAII guard: switches the terminal to raw/unbuffered input mode and restores
// the original attributes on destruction.
class RawModeGuard
{
public:
    RawModeGuard()
    {
        // If we cannot even read the current attributes (e.g. stdin is not a
        // tty), leave m_Valid false so callers can bail out gracefully.
        if (tcgetattr(STDIN_FILENO, &m_OldAttrs) != 0)
        {
            return;
        }

        struct termios Raw = m_OldAttrs;
        // Disable input translation/flow control, force 8-bit characters, and
        // turn off echo, canonical (line) mode, extended input and signal keys.
        Raw.c_iflag &= ~static_cast<tcflag_t>(BRKINT | ICRNL | INPCK | ISTRIP | IXON);
        Raw.c_cflag |= CS8;
        Raw.c_lflag &= ~static_cast<tcflag_t>(ECHO | ICANON | IEXTEN | ISIG);
        // Blocking reads that deliver each byte as soon as it arrives.
        Raw.c_cc[VMIN] = 1;
        Raw.c_cc[VTIME] = 0;
        if (tcsetattr(STDIN_FILENO, TCSANOW, &Raw) == 0)
        {
            m_Valid = true;
        }
    }

    ~RawModeGuard()
    {
        if (m_Valid)
        {
            tcsetattr(STDIN_FILENO, TCSANOW, &m_OldAttrs);
        }
    }

    // True when raw mode was successfully enabled (and will be restored).
    bool IsValid() const { return m_Valid; }

private:
    struct termios m_OldAttrs = {}; // attributes restored on destruction
    bool m_Valid = false;
};
+
// Waits up to TimeoutMs milliseconds for a byte on stdin.
// Returns the byte value (0-255) on success, or -1 on timeout or error.
static int
ReadByteWithTimeout(int TimeoutMs)
{
    struct pollfd StdinPoll = {};
    StdinPoll.fd = STDIN_FILENO;
    StdinPoll.events = POLLIN;

    if (poll(&StdinPoll, 1, TimeoutMs) <= 0 || (StdinPoll.revents & POLLIN) == 0)
    {
        return -1; // timed out, or poll reported an error/hangup
    }

    unsigned char Byte = 0;
    if (read(STDIN_FILENO, &Byte, 1) != 1)
    {
        return -1; // read failed or hit EOF
    }
    return static_cast<int>(Byte);
}
+
+// State for fullscreen live mode (alternate screen + raw input)
+static struct termios s_SavedAttrs = {};
+static bool s_InLiveMode = false;
+
+enum class ConsoleKey
+{
+ Unknown,
+ ArrowUp,
+ ArrowDown,
+ Enter,
+ Escape,
+};
+
static ConsoleKey
ReadKey()
{
    // Blocks for one byte of raw terminal input and decodes it into a
    // ConsoleKey. Arrow keys arrive as the 3-byte sequence ESC '[' 'A'/'B';
    // the short follow-up timeout distinguishes a lone Esc keypress from the
    // start of such a sequence.
    unsigned char c = 0;
    if (read(STDIN_FILENO, &c, 1) != 1)
    {
        return ConsoleKey::Escape; // treat read error as cancel
    }

    if (c == 27) // ESC byte or start of an escape sequence
    {
        int Next = ReadByteWithTimeout(50);
        if (Next == '[')
        {
            int Final = ReadByteWithTimeout(50);
            if (Final == 'A')
            {
                return ConsoleKey::ArrowUp;
            }
            if (Final == 'B')
            {
                return ConsoleKey::ArrowDown;
            }
        }
        // Bare Esc, or an escape sequence we do not recognize.
        return ConsoleKey::Escape;
    }

    if (c == '\r' || c == '\n')
    {
        return ConsoleKey::Enter;
    }

    return ConsoleKey::Unknown;
}
+
+#endif // ZEN_PLATFORM_WINDOWS / POSIX
+
+//////////////////////////////////////////////////////////////////////////
+// Public API
+
uint32_t
TuiConsoleColumns(uint32_t Default)
{
    // Queries the attached terminal for its width in columns; falls back to
    // Default when stdout is not a terminal or the query fails.
#if ZEN_PLATFORM_WINDOWS
    CONSOLE_SCREEN_BUFFER_INFO BufferInfo = {};
    if (!GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &BufferInfo))
    {
        return Default;
    }
    return static_cast<uint32_t>(BufferInfo.dwSize.X);
#else
    struct winsize WindowSize = {};
    if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &WindowSize) != 0 || WindowSize.ws_col == 0)
    {
        return Default;
    }
    return static_cast<uint32_t>(WindowSize.ws_col);
#endif
}
+
void
TuiEnableOutput()
{
    // One-stop switch for printf-based TUI output: enables VT escape-code
    // handling and, on Windows, the UTF-8 console code page. Safe to repeat.
    EnableVirtualTerminal();
#if ZEN_PLATFORM_WINDOWS
    SetConsoleOutputCP(CP_UTF8);
#endif
}

bool
TuiIsStdoutTty()
{
    // True when stdout is a real terminal (not piped/redirected). Computed
    // once and cached for the process lifetime — assumes stdout is not
    // redirected mid-run.
#if ZEN_PLATFORM_WINDOWS
    static bool Cached = [] {
        DWORD dwMode = 0;
        return GetConsoleMode(GetStdHandle(STD_OUTPUT_HANDLE), &dwMode) != 0;
    }();
    return Cached;
#else
    static bool Cached = isatty(STDOUT_FILENO) != 0;
    return Cached;
#endif
}

bool
IsTuiAvailable()
{
    // True when both stdin and stdout are interactive; computed once, cached.
    static bool Cached = CheckIsInteractiveTerminal();
    return Cached;
}
+
+int
+TuiPickOne(std::string_view Title, std::span<const std::string> Items)
+{
+ EnableVirtualTerminal();
+
+#if ZEN_PLATFORM_WINDOWS
+ ConsoleCodePageGuard CodePageGuard(CP_UTF8);
+#else
+ RawModeGuard RawMode;
+ if (!RawMode.IsValid())
+ {
+ return -1;
+ }
+#endif
+
+ const int Count = static_cast<int>(Items.size());
+ int SelectedIndex = 0;
+
+ printf("\n%.*s\n\n", static_cast<int>(Title.size()), Title.data());
+
+ // Hide cursor during interaction
+ printf("\033[?25l");
+
+ // Renders the full entry list and hint footer.
+ // On subsequent calls, moves the cursor back up first to overwrite the previous output.
+ bool FirstRender = true;
+ auto RenderAll = [&] {
+ if (!FirstRender)
+ {
+ printf("\033[%dA", Count + 2); // move up: entries + blank line + hint line
+ }
+ FirstRender = false;
+
+ for (int i = 0; i < Count; ++i)
+ {
+ bool IsSelected = (i == SelectedIndex);
+
+ printf("\r\033[K"); // erase line
+
+ if (IsSelected)
+ {
+ printf("\033[1;7m"); // bold + reverse video
+ }
+
+ // \xe2\x96\xb6 = U+25B6 BLACK RIGHT-POINTING TRIANGLE (▶)
+ const char* Indicator = IsSelected ? " \xe2\x96\xb6 " : " ";
+
+ printf("%s%s", Indicator, Items[i].c_str());
+
+ if (IsSelected)
+ {
+ printf("\033[0m"); // reset attributes
+ }
+
+ printf("\n");
+ }
+
+ // Blank separator line
+ printf("\r\033[K\n");
+
+ // Hint footer
+ // \xe2\x86\x91 = U+2191 ↑ \xe2\x86\x93 = U+2193 ↓
+ printf(
+ "\r\033[K \033[2m\xe2\x86\x91/\xe2\x86\x93\033[0m navigate "
+ "\033[2mEnter\033[0m confirm "
+ "\033[2mEsc\033[0m cancel\n");
+
+ fflush(stdout);
+ };
+
+ RenderAll();
+
+ int Result = -1;
+ bool Done = false;
+ while (!Done)
+ {
+ ConsoleKey Key = ReadKey();
+ switch (Key)
+ {
+ case ConsoleKey::ArrowUp:
+ SelectedIndex = (SelectedIndex - 1 + Count) % Count;
+ RenderAll();
+ break;
+
+ case ConsoleKey::ArrowDown:
+ SelectedIndex = (SelectedIndex + 1) % Count;
+ RenderAll();
+ break;
+
+ case ConsoleKey::Enter:
+ Result = SelectedIndex;
+ Done = true;
+ break;
+
+ case ConsoleKey::Escape:
+ Done = true;
+ break;
+
+ default:
+ break;
+ }
+ }
+
+ // Restore cursor and add a blank line for visual separation
+ printf("\033[?25h\n");
+ fflush(stdout);
+
+ return Result;
+}
+
void
TuiEnterAlternateScreen()
{
    // Switch to fullscreen live-update mode: alternate screen buffer + hidden
    // cursor, and (POSIX) raw terminal input so TuiPollQuit() can poll keys.
    // Must be balanced by a call to TuiExitAlternateScreen().
    EnableVirtualTerminal();
#if ZEN_PLATFORM_WINDOWS
    SetConsoleOutputCP(CP_UTF8);
#endif

    printf("\033[?1049h"); // Enter alternate screen buffer
    printf("\033[?25l");   // Hide cursor
    fflush(stdout);

#if !ZEN_PLATFORM_WINDOWS
    // Save the current attributes so TuiExitAlternateScreen() can restore
    // them; s_InLiveMode records whether raw mode actually took effect.
    if (tcgetattr(STDIN_FILENO, &s_SavedAttrs) == 0)
    {
        struct termios Raw = s_SavedAttrs;
        Raw.c_iflag &= ~static_cast<tcflag_t>(BRKINT | ICRNL | INPCK | ISTRIP | IXON);
        Raw.c_cflag |= CS8;
        Raw.c_lflag &= ~static_cast<tcflag_t>(ECHO | ICANON | IEXTEN | ISIG);
        Raw.c_cc[VMIN] = 1;
        Raw.c_cc[VTIME] = 0;
        if (tcsetattr(STDIN_FILENO, TCSANOW, &Raw) == 0)
        {
            s_InLiveMode = true;
        }
    }
#endif
}

void
TuiExitAlternateScreen()
{
    // Leave fullscreen mode. Safe to call even if TuiEnterAlternateScreen()
    // was never called: s_InLiveMode guards the terminal-attribute restore.
    printf("\033[?25h");   // Show cursor
    printf("\033[?1049l"); // Exit alternate screen buffer
    fflush(stdout);

#if !ZEN_PLATFORM_WINDOWS
    if (s_InLiveMode)
    {
        tcsetattr(STDIN_FILENO, TCSANOW, &s_SavedAttrs);
        s_InLiveMode = false;
    }
#endif
}

void
TuiCursorHome()
{
    // ANSI "cursor position" with no arguments = row 1, column 1.
    printf("\033[H");
}
+
uint32_t
TuiConsoleRows(uint32_t Default)
{
    // Queries the attached terminal for its height in rows; falls back to
    // Default when stdout is not a terminal or the query fails.
#if ZEN_PLATFORM_WINDOWS
    CONSOLE_SCREEN_BUFFER_INFO BufferInfo = {};
    if (!GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &BufferInfo))
    {
        return Default;
    }
    // Visible window height, not the scrollback buffer height.
    return static_cast<uint32_t>(BufferInfo.srWindow.Bottom - BufferInfo.srWindow.Top + 1);
#else
    struct winsize WindowSize = {};
    if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &WindowSize) != 0 || WindowSize.ws_row == 0)
    {
        return Default;
    }
    return static_cast<uint32_t>(WindowSize.ws_row);
#endif
}
+
bool
TuiPollQuit()
{
    // Non-blocking quit check for live mode: returns true if a pending
    // keypress is Esc, 'q'/'Q' or (POSIX) Ctrl+C. Consumes pending input.
#if ZEN_PLATFORM_WINDOWS
    HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE);
    DWORD dwCount = 0;
    if (!GetNumberOfConsoleInputEvents(hStdin, &dwCount) || dwCount == 0)
    {
        return false; // nothing pending — do not block
    }
    INPUT_RECORD Record{};
    DWORD dwRead = 0;
    // Drain all queued records, looking for a quit key-down among them.
    while (PeekConsoleInputA(hStdin, &Record, 1, &dwRead) && dwRead > 0)
    {
        ReadConsoleInputA(hStdin, &Record, 1, &dwRead);
        if (Record.EventType == KEY_EVENT && Record.Event.KeyEvent.bKeyDown)
        {
            WORD vk = Record.Event.KeyEvent.wVirtualKeyCode;
            char ch = Record.Event.KeyEvent.uChar.AsciiChar;
            if (vk == VK_ESCAPE || ch == 'q' || ch == 'Q')
            {
                return true;
            }
        }
    }
    return false;
#else
    // Non-blocking read: character 3 = Ctrl+C, 27 = Esc, 'q'/'Q' = quit
    int b = ReadByteWithTimeout(0);
    return (b == 3 || b == 27 || b == 'q' || b == 'Q');
#endif
}
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/commandlineoptions.h b/src/zenutil/include/zenutil/config/commandlineoptions.h
index 01cceedb1..01cceedb1 100644
--- a/src/zenutil/include/zenutil/commandlineoptions.h
+++ b/src/zenutil/include/zenutil/config/commandlineoptions.h
diff --git a/src/zenutil/include/zenutil/environmentoptions.h b/src/zenutil/include/zenutil/config/environmentoptions.h
index 7418608e4..1ecdf591a 100644
--- a/src/zenutil/include/zenutil/environmentoptions.h
+++ b/src/zenutil/include/zenutil/config/environmentoptions.h
@@ -3,7 +3,7 @@
#pragma once
#include <zencore/string.h>
-#include <zenutil/commandlineoptions.h>
+#include <zenutil/config/commandlineoptions.h>
namespace zen {
diff --git a/src/zenutil/include/zenutil/config/loggingconfig.h b/src/zenutil/include/zenutil/config/loggingconfig.h
new file mode 100644
index 000000000..b55b2d9f7
--- /dev/null
+++ b/src/zenutil/include/zenutil/config/loggingconfig.h
@@ -0,0 +1,37 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logbase.h>
+#include <filesystem>
+#include <string>
+
+namespace cxxopts {
+class Options;
+}
+
+namespace zen {
+
// Aggregated logging configuration, typically populated from CLI options via
// ZenLoggingCmdLineOptions and consumed by ApplyLoggingOptions().
struct ZenLoggingConfig
{
    bool NoConsoleOutput = false;     // Control default use of stdout for diagnostics
    bool QuietConsole = false;        // Configure console logger output to level WARN
    std::filesystem::path AbsLogFile; // Absolute path to main log file
    std::string Loggers[logging::LogLevelCount]; // Per-level comma-separated logger name lists
    std::string LogId;                // Id for tagging log output
    std::string OtelEndpointUri;      // OpenTelemetry endpoint URI
};

// Applies LoggingConfig to the logging subsystem (per-level logger lists,
// --quiet console demotion).
void ApplyLoggingOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig);

// Binds the "logging" cxxopts option group to a ZenLoggingConfig and
// finalizes path-like option values after parsing.
class ZenLoggingCmdLineOptions
{
public:
    // Registers the logging CLI options on `options`, binding most values
    // directly into LoggingConfig.
    void AddCliOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig);
    // Post-parse step: resolves captured raw values (e.g. --abslog) into
    // LoggingConfig.
    void ApplyOptions(ZenLoggingConfig& LoggingConfig);

private:
    std::string m_AbsLogFile; // raw --abslog value captured at parse time
};
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/consoletui.h b/src/zenutil/include/zenutil/consoletui.h
new file mode 100644
index 000000000..5f74fa82b
--- /dev/null
+++ b/src/zenutil/include/zenutil/consoletui.h
@@ -0,0 +1,60 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <cstdint>
+#include <span>
+#include <string>
+#include <string_view>
+
+namespace zen {
+
+// Returns the width of the console in columns, or Default if it cannot be determined.
+uint32_t TuiConsoleColumns(uint32_t Default = 120);
+
+// Enables ANSI/VT escape code processing and UTF-8 console output.
+// Call once before printing ANSI escape sequences or multi-byte UTF-8 characters via printf.
+// Safe to call multiple times. No-op on POSIX (escape codes are native there).
+void TuiEnableOutput();
+
+// Returns true if stdout is connected to a real terminal (not piped or redirected).
+// Useful for deciding whether to use ANSI escape codes for progress output.
+bool TuiIsStdoutTty();
+
+// Returns true if both stdin and stdout are connected to an interactive terminal
+// (i.e. not piped or redirected). Must be checked before calling TuiPickOne().
+bool IsTuiAvailable();
+
+// Displays a cursor-navigable single-select list in the terminal.
+//
+// - Title: a short description printed once above the list
+// - Items: pre-formatted display labels, one per selectable entry
+//
+// Arrow keys (↑/↓) navigate the selection, Enter confirms, Esc cancels.
+// Returns the index of the selected item, or -1 if the user cancelled.
+//
+// Precondition: IsTuiAvailable() must be true.
+int TuiPickOne(std::string_view Title, std::span<const std::string> Items);
+
+// Enter the alternate screen buffer for fullscreen live-update mode.
+// Hides the cursor. On POSIX, switches to raw/unbuffered terminal input.
+// Must be balanced by a call to TuiExitAlternateScreen().
+// Precondition: IsTuiAvailable() must be true.
+void TuiEnterAlternateScreen();
+
+// Exit alternate screen buffer. Restores the cursor and, on POSIX, the original
+// terminal mode. Safe to call even if TuiEnterAlternateScreen() was not called.
+void TuiExitAlternateScreen();
+
+// Move the cursor to the top-left corner of the terminal (row 1, col 1).
+void TuiCursorHome();
+
+// Returns the height of the console in rows, or Default if it cannot be determined.
+uint32_t TuiConsoleRows(uint32_t Default = 40);
+
+// Non-blocking check: returns true if the user has pressed a key that means quit
+// (Esc, 'q', 'Q', or Ctrl+C). Consumes the event if one is pending.
+// Should only be called while in alternate screen mode.
+bool TuiPollQuit();
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/logging.h b/src/zenutil/include/zenutil/logging.h
index 85ddc86cd..95419c274 100644
--- a/src/zenutil/include/zenutil/logging.h
+++ b/src/zenutil/include/zenutil/logging.h
@@ -3,19 +3,12 @@
#pragma once
#include <zencore/logging.h>
+#include <zencore/logging/sink.h>
#include <filesystem>
#include <memory>
#include <string>
-namespace spdlog::sinks {
-class sink;
-}
-
-namespace spdlog {
-using sink_ptr = std::shared_ptr<sinks::sink>;
-}
-
//////////////////////////////////////////////////////////////////////////
//
// Logging utilities
@@ -45,6 +38,6 @@ void FinishInitializeLogging(const LoggingOptions& LoggingOptions);
void InitializeLogging(const LoggingOptions& LoggingOptions);
void ShutdownLogging();
-spdlog::sink_ptr GetFileSink();
+logging::SinkPtr GetFileSink();
} // namespace zen
diff --git a/src/zenutil/include/zenutil/logging/fullformatter.h b/src/zenutil/include/zenutil/logging/fullformatter.h
index 9f245becd..33cb94dae 100644
--- a/src/zenutil/include/zenutil/logging/fullformatter.h
+++ b/src/zenutil/include/zenutil/logging/fullformatter.h
@@ -2,21 +2,19 @@
#pragma once
+#include <zencore/logging/formatter.h>
+#include <zencore/logging/helpers.h>
#include <zencore/memory/llm.h>
#include <zencore/zencore.h>
#include <string_view>
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <spdlog/formatter.h>
-ZEN_THIRD_PARTY_INCLUDES_END
-
namespace zen::logging {
-class full_formatter final : public spdlog::formatter
+class FullFormatter final : public Formatter
{
public:
- full_formatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch)
+ FullFormatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch)
: m_Epoch(Epoch)
, m_LogId(LogId)
, m_LinePrefix(128, ' ')
@@ -24,16 +22,19 @@ public:
{
}
- full_formatter(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {}
+ FullFormatter(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {}
- virtual std::unique_ptr<formatter> clone() const override
+ virtual std::unique_ptr<Formatter> Clone() const override
{
ZEN_MEMSCOPE(ELLMTag::Logging);
- // Note: this does not properly clone m_UseFullDate
- return std::make_unique<full_formatter>(m_LogId, m_Epoch);
+ if (m_UseFullDate)
+ {
+ return std::make_unique<FullFormatter>(m_LogId);
+ }
+ return std::make_unique<FullFormatter>(m_LogId, m_Epoch);
}
- virtual void format(const spdlog::details::log_msg& msg, spdlog::memory_buf_t& OutBuffer) override
+ virtual void Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) override
{
ZEN_MEMSCOPE(ELLMTag::Logging);
@@ -44,38 +45,38 @@ public:
std::chrono::seconds TimestampSeconds;
- std::chrono::milliseconds millis;
+ std::chrono::milliseconds Millis;
if (m_UseFullDate)
{
- TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(msg.time.time_since_epoch());
+ TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(Msg.GetTime().time_since_epoch());
if (TimestampSeconds != m_LastLogSecs)
{
RwLock::ExclusiveLockScope _(m_TimestampLock);
m_LastLogSecs = TimestampSeconds;
- m_CachedLocalTm = spdlog::details::os::localtime(spdlog::log_clock::to_time_t(msg.time));
+ m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime()));
m_CachedDatetime.clear();
m_CachedDatetime.push_back('[');
- spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_year % 100, m_CachedDatetime);
+ helpers::Pad2(m_CachedLocalTm.tm_year % 100, m_CachedDatetime);
m_CachedDatetime.push_back('-');
- spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_mon + 1, m_CachedDatetime);
+ helpers::Pad2(m_CachedLocalTm.tm_mon + 1, m_CachedDatetime);
m_CachedDatetime.push_back('-');
- spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_mday, m_CachedDatetime);
+ helpers::Pad2(m_CachedLocalTm.tm_mday, m_CachedDatetime);
m_CachedDatetime.push_back(' ');
- spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_hour, m_CachedDatetime);
+ helpers::Pad2(m_CachedLocalTm.tm_hour, m_CachedDatetime);
m_CachedDatetime.push_back(':');
- spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_min, m_CachedDatetime);
+ helpers::Pad2(m_CachedLocalTm.tm_min, m_CachedDatetime);
m_CachedDatetime.push_back(':');
- spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_sec, m_CachedDatetime);
+ helpers::Pad2(m_CachedLocalTm.tm_sec, m_CachedDatetime);
m_CachedDatetime.push_back('.');
}
- millis = spdlog::details::fmt_helper::time_fraction<std::chrono::milliseconds>(msg.time);
+ Millis = helpers::TimeFraction<std::chrono::milliseconds>(Msg.GetTime());
}
else
{
- auto ElapsedTime = msg.time - m_Epoch;
+ auto ElapsedTime = Msg.GetTime() - m_Epoch;
TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(ElapsedTime);
if (m_CacheTimestamp.load() != TimestampSeconds)
@@ -93,15 +94,15 @@ public:
m_CachedDatetime.clear();
m_CachedDatetime.push_back('[');
- spdlog::details::fmt_helper::pad2(LogHours, m_CachedDatetime);
+ helpers::Pad2(LogHours, m_CachedDatetime);
m_CachedDatetime.push_back(':');
- spdlog::details::fmt_helper::pad2(LogMins, m_CachedDatetime);
+ helpers::Pad2(LogMins, m_CachedDatetime);
m_CachedDatetime.push_back(':');
- spdlog::details::fmt_helper::pad2(LogSecs, m_CachedDatetime);
+ helpers::Pad2(LogSecs, m_CachedDatetime);
m_CachedDatetime.push_back('.');
}
- millis = std::chrono::duration_cast<std::chrono::milliseconds>(ElapsedTime - TimestampSeconds);
+ Millis = std::chrono::duration_cast<std::chrono::milliseconds>(ElapsedTime - TimestampSeconds);
}
{
@@ -109,44 +110,43 @@ public:
OutBuffer.append(m_CachedDatetime.begin(), m_CachedDatetime.end());
}
- spdlog::details::fmt_helper::pad3(static_cast<uint32_t>(millis.count()), OutBuffer);
+ helpers::Pad3(static_cast<uint32_t>(Millis.count()), OutBuffer);
OutBuffer.push_back(']');
OutBuffer.push_back(' ');
if (!m_LogId.empty())
{
OutBuffer.push_back('[');
- spdlog::details::fmt_helper::append_string_view(m_LogId, OutBuffer);
+ helpers::AppendStringView(m_LogId, OutBuffer);
OutBuffer.push_back(']');
OutBuffer.push_back(' ');
}
// append logger name if exists
- if (msg.logger_name.size() > 0)
+ if (Msg.GetLoggerName().size() > 0)
{
OutBuffer.push_back('[');
- spdlog::details::fmt_helper::append_string_view(msg.logger_name, OutBuffer);
+ helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer);
OutBuffer.push_back(']');
OutBuffer.push_back(' ');
}
OutBuffer.push_back('[');
// wrap the level name with color
- msg.color_range_start = OutBuffer.size();
- spdlog::details::fmt_helper::append_string_view(spdlog::level::to_string_view(msg.level), OutBuffer);
- msg.color_range_end = OutBuffer.size();
+ Msg.ColorRangeStart = OutBuffer.size();
+ helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), OutBuffer);
+ Msg.ColorRangeEnd = OutBuffer.size();
OutBuffer.push_back(']');
OutBuffer.push_back(' ');
// add source location if present
- if (!msg.source.empty())
+ if (Msg.GetSource())
{
OutBuffer.push_back('[');
- const char* filename =
- spdlog::details::short_filename_formatter<spdlog::details::null_scoped_padder>::basename(msg.source.filename);
- spdlog::details::fmt_helper::append_string_view(filename, OutBuffer);
+ const char* Filename = helpers::ShortFilename(Msg.GetSource().Filename);
+ helpers::AppendStringView(Filename, OutBuffer);
OutBuffer.push_back(':');
- spdlog::details::fmt_helper::append_int(msg.source.line, OutBuffer);
+ helpers::AppendInt(Msg.GetSource().Line, OutBuffer);
OutBuffer.push_back(']');
OutBuffer.push_back(' ');
}
@@ -156,8 +156,9 @@ public:
const size_t LinePrefixCount = Min<size_t>(OutBuffer.size(), m_LinePrefix.size());
- auto ItLineBegin = msg.payload.begin();
- auto ItMessageEnd = msg.payload.end();
+ auto MsgPayload = Msg.GetPayload();
+ auto ItLineBegin = MsgPayload.begin();
+ auto ItMessageEnd = MsgPayload.end();
bool IsFirstline = true;
{
@@ -170,9 +171,9 @@ public:
}
else
{
- spdlog::details::fmt_helper::append_string_view(std::string_view(m_LinePrefix.data(), LinePrefixCount), OutBuffer);
+ helpers::AppendStringView(std::string_view(m_LinePrefix.data(), LinePrefixCount), OutBuffer);
}
- spdlog::details::fmt_helper::append_string_view(spdlog::string_view_t(&*ItLineBegin, ItLineEnd - ItLineBegin), OutBuffer);
+ helpers::AppendStringView(std::string_view(&*ItLineBegin, ItLineEnd - ItLineBegin), OutBuffer);
};
while (ItLineEnd != ItMessageEnd)
@@ -187,7 +188,7 @@ public:
if (ItLineBegin != ItMessageEnd)
{
EmitLine();
- spdlog::details::fmt_helper::append_string_view("\n"sv, OutBuffer);
+ helpers::AppendStringView("\n"sv, OutBuffer);
}
}
}
@@ -197,7 +198,7 @@ private:
std::tm m_CachedLocalTm;
std::chrono::seconds m_LastLogSecs{std::chrono::seconds(87654321)};
std::atomic<std::chrono::seconds> m_CacheTimestamp{std::chrono::seconds(87654321)};
- spdlog::memory_buf_t m_CachedDatetime;
+ MemoryBuffer m_CachedDatetime;
std::string m_LogId;
std::string m_LinePrefix;
bool m_UseFullDate = true;
diff --git a/src/zenutil/include/zenutil/logging/jsonformatter.h b/src/zenutil/include/zenutil/logging/jsonformatter.h
index 3f660e421..216b1b5e5 100644
--- a/src/zenutil/include/zenutil/logging/jsonformatter.h
+++ b/src/zenutil/include/zenutil/logging/jsonformatter.h
@@ -2,27 +2,26 @@
#pragma once
+#include <zencore/logging/formatter.h>
+#include <zencore/logging/helpers.h>
#include <zencore/memory/llm.h>
#include <zencore/zencore.h>
#include <string_view>
-
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <spdlog/formatter.h>
-ZEN_THIRD_PARTY_INCLUDES_END
+#include <unordered_map>
namespace zen::logging {
using namespace std::literals;
-class json_formatter final : public spdlog::formatter
+class JsonFormatter final : public Formatter
{
public:
- json_formatter(std::string_view LogId) : m_LogId(LogId) {}
+ JsonFormatter(std::string_view LogId) : m_LogId(LogId) {}
- virtual std::unique_ptr<formatter> clone() const override { return std::make_unique<json_formatter>(m_LogId); }
+ virtual std::unique_ptr<Formatter> Clone() const override { return std::make_unique<JsonFormatter>(m_LogId); }
- virtual void format(const spdlog::details::log_msg& msg, spdlog::memory_buf_t& dest) override
+ virtual void Format(const LogMessage& Msg, MemoryBuffer& Dest) override
{
ZEN_MEMSCOPE(ELLMTag::Logging);
@@ -30,141 +29,132 @@ public:
using std::chrono::milliseconds;
using std::chrono::seconds;
- auto secs = std::chrono::duration_cast<seconds>(msg.time.time_since_epoch());
- if (secs != m_LastLogSecs)
+ auto Secs = std::chrono::duration_cast<seconds>(Msg.GetTime().time_since_epoch());
+ if (Secs != m_LastLogSecs)
{
- m_CachedTm = spdlog::details::os::localtime(spdlog::log_clock::to_time_t(msg.time));
- m_LastLogSecs = secs;
- }
-
- const auto& tm_time = m_CachedTm;
+ RwLock::ExclusiveLockScope _(m_TimestampLock);
+ m_CachedTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime()));
+ m_LastLogSecs = Secs;
- // cache the date/time part for the next second.
-
- if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0)
- {
+ // cache the date/time part for the next second.
m_CachedDatetime.clear();
- spdlog::details::fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime);
+ helpers::AppendInt(m_CachedTm.tm_year + 1900, m_CachedDatetime);
m_CachedDatetime.push_back('-');
- spdlog::details::fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime);
+ helpers::Pad2(m_CachedTm.tm_mon + 1, m_CachedDatetime);
m_CachedDatetime.push_back('-');
- spdlog::details::fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime);
+ helpers::Pad2(m_CachedTm.tm_mday, m_CachedDatetime);
m_CachedDatetime.push_back(' ');
- spdlog::details::fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime);
+ helpers::Pad2(m_CachedTm.tm_hour, m_CachedDatetime);
m_CachedDatetime.push_back(':');
- spdlog::details::fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime);
+ helpers::Pad2(m_CachedTm.tm_min, m_CachedDatetime);
m_CachedDatetime.push_back(':');
- spdlog::details::fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime);
+ helpers::Pad2(m_CachedTm.tm_sec, m_CachedDatetime);
m_CachedDatetime.push_back('.');
-
- m_CacheTimestamp = secs;
}
- dest.append("{"sv);
- dest.append("\"time\": \""sv);
- dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end());
- auto millis = spdlog::details::fmt_helper::time_fraction<milliseconds>(msg.time);
- spdlog::details::fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest);
- dest.append("\", "sv);
+ helpers::AppendStringView("{"sv, Dest);
+ helpers::AppendStringView("\"time\": \""sv, Dest);
+ {
+ RwLock::SharedLockScope _(m_TimestampLock);
+ Dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end());
+ }
+ auto Millis = helpers::TimeFraction<milliseconds>(Msg.GetTime());
+ helpers::Pad3(static_cast<uint32_t>(Millis.count()), Dest);
+ helpers::AppendStringView("\", "sv, Dest);
- dest.append("\"status\": \""sv);
- dest.append(spdlog::level::to_string_view(msg.level));
- dest.append("\", "sv);
+ helpers::AppendStringView("\"status\": \""sv, Dest);
+ helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest);
+ helpers::AppendStringView("\", "sv, Dest);
- dest.append("\"source\": \""sv);
- dest.append("zenserver"sv);
- dest.append("\", "sv);
+ helpers::AppendStringView("\"source\": \""sv, Dest);
+ helpers::AppendStringView("zenserver"sv, Dest);
+ helpers::AppendStringView("\", "sv, Dest);
- dest.append("\"service\": \""sv);
- dest.append("zencache"sv);
- dest.append("\", "sv);
+ helpers::AppendStringView("\"service\": \""sv, Dest);
+ helpers::AppendStringView("zencache"sv, Dest);
+ helpers::AppendStringView("\", "sv, Dest);
if (!m_LogId.empty())
{
- dest.append("\"id\": \""sv);
- dest.append(m_LogId);
- dest.append("\", "sv);
+ helpers::AppendStringView("\"id\": \""sv, Dest);
+ helpers::AppendStringView(m_LogId, Dest);
+ helpers::AppendStringView("\", "sv, Dest);
}
- if (msg.logger_name.size() > 0)
+ if (Msg.GetLoggerName().size() > 0)
{
- dest.append("\"logger.name\": \""sv);
- dest.append(msg.logger_name);
- dest.append("\", "sv);
+ helpers::AppendStringView("\"logger.name\": \""sv, Dest);
+ helpers::AppendStringView(Msg.GetLoggerName(), Dest);
+ helpers::AppendStringView("\", "sv, Dest);
}
- if (msg.thread_id != 0)
+ if (Msg.GetThreadId() != 0)
{
- dest.append("\"logger.thread_name\": \""sv);
- spdlog::details::fmt_helper::pad_uint(msg.thread_id, 0, dest);
- dest.append("\", "sv);
+ helpers::AppendStringView("\"logger.thread_name\": \""sv, Dest);
+ helpers::PadUint(Msg.GetThreadId(), 0, Dest);
+ helpers::AppendStringView("\", "sv, Dest);
}
- if (!msg.source.empty())
+ if (Msg.GetSource())
{
- dest.append("\"file\": \""sv);
- WriteEscapedString(
- dest,
- spdlog::details::short_filename_formatter<spdlog::details::null_scoped_padder>::basename(msg.source.filename));
- dest.append("\","sv);
-
- dest.append("\"line\": \""sv);
- dest.append(fmt::format("{}", msg.source.line));
- dest.append("\","sv);
-
- dest.append("\"logger.method_name\": \""sv);
- WriteEscapedString(dest, msg.source.funcname);
- dest.append("\", "sv);
+ helpers::AppendStringView("\"file\": \""sv, Dest);
+ WriteEscapedString(Dest, helpers::ShortFilename(Msg.GetSource().Filename));
+ helpers::AppendStringView("\","sv, Dest);
+
+ helpers::AppendStringView("\"line\": \""sv, Dest);
+ helpers::AppendInt(Msg.GetSource().Line, Dest);
+ helpers::AppendStringView("\","sv, Dest);
}
- dest.append("\"message\": \""sv);
- WriteEscapedString(dest, msg.payload);
- dest.append("\""sv);
+ helpers::AppendStringView("\"message\": \""sv, Dest);
+ WriteEscapedString(Dest, Msg.GetPayload());
+ helpers::AppendStringView("\""sv, Dest);
- dest.append("}\n"sv);
+ helpers::AppendStringView("}\n"sv, Dest);
}
private:
- static inline const std::unordered_map<char, std::string_view> SpecialCharacterMap{{'\b', "\\b"sv},
- {'\f', "\\f"sv},
- {'\n', "\\n"sv},
- {'\r', "\\r"sv},
- {'\t', "\\t"sv},
- {'"', "\\\""sv},
- {'\\', "\\\\"sv}};
-
- static void WriteEscapedString(spdlog::memory_buf_t& dest, const spdlog::string_view_t& payload)
+ static inline const std::unordered_map<char, std::string_view> s_SpecialCharacterMap{{'\b', "\\b"sv},
+ {'\f', "\\f"sv},
+ {'\n', "\\n"sv},
+ {'\r', "\\r"sv},
+ {'\t', "\\t"sv},
+ {'"', "\\\""sv},
+ {'\\', "\\\\"sv}};
+
+ static void WriteEscapedString(MemoryBuffer& Dest, const std::string_view& Text)
{
- const char* RangeStart = payload.begin();
- for (const char* It = RangeStart; It != payload.end(); ++It)
+ const char* RangeStart = Text.data();
+ const char* End = Text.data() + Text.size();
+ for (const char* It = RangeStart; It != End; ++It)
{
- if (auto SpecialIt = SpecialCharacterMap.find(*It); SpecialIt != SpecialCharacterMap.end())
+ if (auto SpecialIt = s_SpecialCharacterMap.find(*It); SpecialIt != s_SpecialCharacterMap.end())
{
if (RangeStart != It)
{
- dest.append(RangeStart, It);
+ Dest.append(RangeStart, It);
}
- dest.append(SpecialIt->second);
+ helpers::AppendStringView(SpecialIt->second, Dest);
RangeStart = It + 1;
}
}
- if (RangeStart != payload.end())
+ if (RangeStart != End)
{
- dest.append(RangeStart, payload.end());
+ Dest.append(RangeStart, End);
}
};
std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0};
std::chrono::seconds m_LastLogSecs{0};
- std::chrono::seconds m_CacheTimestamp{0};
- spdlog::memory_buf_t m_CachedDatetime;
+ MemoryBuffer m_CachedDatetime;
std::string m_LogId;
+ RwLock m_TimestampLock;
};
} // namespace zen::logging
diff --git a/src/zenutil/include/zenutil/logging/rotatingfilesink.h b/src/zenutil/include/zenutil/logging/rotatingfilesink.h
index 8901b7779..cebc5b110 100644
--- a/src/zenutil/include/zenutil/logging/rotatingfilesink.h
+++ b/src/zenutil/include/zenutil/logging/rotatingfilesink.h
@@ -3,14 +3,11 @@
#pragma once
#include <zencore/basicfile.h>
+#include <zencore/logging/formatter.h>
+#include <zencore/logging/messageonlyformatter.h>
+#include <zencore/logging/sink.h>
#include <zencore/memory/llm.h>
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <spdlog/details/log_msg.h>
-#include <spdlog/pattern_formatter.h>
-#include <spdlog/sinks/sink.h>
-ZEN_THIRD_PARTY_INCLUDES_END
-
#include <atomic>
#include <filesystem>
@@ -19,13 +16,14 @@ namespace zen::logging {
// Basically the same functionality as spdlog::sinks::rotating_file_sink with the biggest difference
// being that it just ignores any errors when writing/rotating files and keeps chugging on.
// It will keep trying to log, and if it starts to work it will continue to log.
-class RotatingFileSink : public spdlog::sinks::sink
+class RotatingFileSink : public Sink
{
public:
RotatingFileSink(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen = false)
: m_BaseFilename(BaseFilename)
, m_MaxSize(MaxSize)
, m_MaxFiles(MaxFiles)
+ , m_Formatter(std::make_unique<MessageOnlyFormatter>())
{
ZEN_MEMSCOPE(ELLMTag::Logging);
@@ -76,18 +74,21 @@ public:
RotatingFileSink& operator=(const RotatingFileSink&) = delete;
RotatingFileSink& operator=(RotatingFileSink&&) = delete;
- virtual void log(const spdlog::details::log_msg& msg) override
+ virtual void Log(const LogMessage& Msg) override
{
ZEN_MEMSCOPE(ELLMTag::Logging);
try
{
- spdlog::memory_buf_t Formatted;
- if (TrySinkIt(msg, Formatted))
+ MemoryBuffer Formatted;
+ if (TrySinkIt(Msg, Formatted))
{
return;
}
- while (true)
+
+ // This intentionally has no limit on the number of retries, see
+ // comment above.
+ for (;;)
{
{
RwLock::ExclusiveLockScope RotateLock(m_Lock);
@@ -113,7 +114,7 @@ public:
// Silently eat errors
}
}
- virtual void flush() override
+ virtual void Flush() override
{
if (!m_NeedFlush)
{
@@ -138,28 +139,14 @@ public:
m_NeedFlush = false;
}
- virtual void set_pattern(const std::string& pattern) override
+ virtual void SetFormatter(std::unique_ptr<Formatter> InFormatter) override
{
ZEN_MEMSCOPE(ELLMTag::Logging);
try
{
RwLock::ExclusiveLockScope _(m_Lock);
- m_Formatter = spdlog::details::make_unique<spdlog::pattern_formatter>(pattern);
- }
- catch (const std::exception&)
- {
- // Silently eat errors
- }
- }
- virtual void set_formatter(std::unique_ptr<spdlog::formatter> sink_formatter) override
- {
- ZEN_MEMSCOPE(ELLMTag::Logging);
-
- try
- {
- RwLock::ExclusiveLockScope _(m_Lock);
- m_Formatter = std::move(sink_formatter);
+ m_Formatter = std::move(InFormatter);
}
catch (const std::exception&)
{
@@ -186,11 +173,17 @@ private:
return;
}
- // If we fail to rotate, try extending the current log file
m_CurrentSize = m_CurrentFile.FileSize(OutEc);
+ if (OutEc)
+ {
+ // FileSize failed but we have an open file — reset to 0
+ // so we can at least attempt writes from the start
+ m_CurrentSize = 0;
+ OutEc.clear();
+ }
}
- bool TrySinkIt(const spdlog::details::log_msg& msg, spdlog::memory_buf_t& OutFormatted)
+ bool TrySinkIt(const LogMessage& Msg, MemoryBuffer& OutFormatted)
{
ZEN_MEMSCOPE(ELLMTag::Logging);
@@ -199,15 +192,15 @@ private:
{
return false;
}
- m_Formatter->format(msg, OutFormatted);
- size_t add_size = OutFormatted.size();
- size_t write_pos = m_CurrentSize.fetch_add(add_size);
- if (write_pos + add_size > m_MaxSize)
+ m_Formatter->Format(Msg, OutFormatted);
+ size_t AddSize = OutFormatted.size();
+ size_t WritePos = m_CurrentSize.fetch_add(AddSize);
+ if (WritePos + AddSize > m_MaxSize)
{
return false;
}
std::error_code Ec;
- m_CurrentFile.Write(OutFormatted.data(), OutFormatted.size(), write_pos, Ec);
+ m_CurrentFile.Write(OutFormatted.data(), OutFormatted.size(), WritePos, Ec);
if (Ec)
{
return false;
@@ -216,7 +209,7 @@ private:
return true;
}
- bool TrySinkIt(const spdlog::memory_buf_t& Formatted)
+ bool TrySinkIt(const MemoryBuffer& Formatted)
{
ZEN_MEMSCOPE(ELLMTag::Logging);
@@ -225,15 +218,15 @@ private:
{
return false;
}
- size_t add_size = Formatted.size();
- size_t write_pos = m_CurrentSize.fetch_add(add_size);
- if (write_pos + add_size > m_MaxSize)
+ size_t AddSize = Formatted.size();
+ size_t WritePos = m_CurrentSize.fetch_add(AddSize);
+ if (WritePos + AddSize > m_MaxSize)
{
return false;
}
std::error_code Ec;
- m_CurrentFile.Write(Formatted.data(), Formatted.size(), write_pos, Ec);
+ m_CurrentFile.Write(Formatted.data(), Formatted.size(), WritePos, Ec);
if (Ec)
{
return false;
@@ -242,14 +235,14 @@ private:
return true;
}
- RwLock m_Lock;
- const std::filesystem::path m_BaseFilename;
- std::unique_ptr<spdlog::formatter> m_Formatter;
- std::atomic_size_t m_CurrentSize;
- const std::size_t m_MaxSize;
- const std::size_t m_MaxFiles;
- BasicFile m_CurrentFile;
- std::atomic<bool> m_NeedFlush = false;
+ RwLock m_Lock;
+ const std::filesystem::path m_BaseFilename;
+ const std::size_t m_MaxSize;
+ const std::size_t m_MaxFiles;
+ std::unique_ptr<Formatter> m_Formatter;
+ std::atomic_size_t m_CurrentSize;
+ BasicFile m_CurrentFile;
+ std::atomic<bool> m_NeedFlush = false;
};
} // namespace zen::logging
diff --git a/src/zenutil/include/zenutil/logging/testformatter.h b/src/zenutil/include/zenutil/logging/testformatter.h
deleted file mode 100644
index 0b0c191fb..000000000
--- a/src/zenutil/include/zenutil/logging/testformatter.h
+++ /dev/null
@@ -1,160 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zencore/memory/llm.h>
-
-#include <spdlog/spdlog.h>
-
-namespace zen::logging {
-
-class full_test_formatter final : public spdlog::formatter
-{
-public:
- full_test_formatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) : m_Epoch(Epoch), m_LogId(LogId)
- {
- }
-
- virtual std::unique_ptr<formatter> clone() const override
- {
- ZEN_MEMSCOPE(ELLMTag::Logging);
- return std::make_unique<full_test_formatter>(m_LogId, m_Epoch);
- }
-
- static constexpr bool UseDate = false;
-
- virtual void format(const spdlog::details::log_msg& msg, spdlog::memory_buf_t& dest) override
- {
- ZEN_MEMSCOPE(ELLMTag::Logging);
-
- using namespace std::literals;
-
- if constexpr (UseDate)
- {
- auto secs = std::chrono::duration_cast<std::chrono::seconds>(msg.time.time_since_epoch());
- if (secs != m_LastLogSecs)
- {
- m_CachedTm = spdlog::details::os::localtime(spdlog::log_clock::to_time_t(msg.time));
- m_LastLogSecs = secs;
- }
- }
-
- const auto& tm_time = m_CachedTm;
-
- // cache the date/time part for the next second.
- auto duration = msg.time - m_Epoch;
- auto secs = std::chrono::duration_cast<std::chrono::seconds>(duration);
-
- if (m_CacheTimestamp != secs)
- {
- RwLock::ExclusiveLockScope _(m_TimestampLock);
-
- m_CachedDatetime.clear();
- m_CachedDatetime.push_back('[');
-
- if constexpr (UseDate)
- {
- spdlog::details::fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime);
- m_CachedDatetime.push_back('-');
-
- spdlog::details::fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime);
- m_CachedDatetime.push_back('-');
-
- spdlog::details::fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime);
- m_CachedDatetime.push_back(' ');
-
- spdlog::details::fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime);
- m_CachedDatetime.push_back(':');
-
- spdlog::details::fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime);
- m_CachedDatetime.push_back(':');
-
- spdlog::details::fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime);
- }
- else
- {
- int Count = int(secs.count());
-
- const int LogSecs = Count % 60;
- Count /= 60;
-
- const int LogMins = Count % 60;
- Count /= 60;
-
- const int LogHours = Count;
-
- spdlog::details::fmt_helper::pad2(LogHours, m_CachedDatetime);
- m_CachedDatetime.push_back(':');
- spdlog::details::fmt_helper::pad2(LogMins, m_CachedDatetime);
- m_CachedDatetime.push_back(':');
- spdlog::details::fmt_helper::pad2(LogSecs, m_CachedDatetime);
- }
-
- m_CachedDatetime.push_back('.');
-
- m_CacheTimestamp = secs;
- }
-
- {
- RwLock::SharedLockScope _(m_TimestampLock);
- dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end());
- }
-
- auto millis = spdlog::details::fmt_helper::time_fraction<std::chrono::milliseconds>(msg.time);
- spdlog::details::fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest);
- dest.push_back(']');
- dest.push_back(' ');
-
- if (!m_LogId.empty())
- {
- dest.push_back('[');
- spdlog::details::fmt_helper::append_string_view(m_LogId, dest);
- dest.push_back(']');
- dest.push_back(' ');
- }
-
- // append logger name if exists
- if (msg.logger_name.size() > 0)
- {
- dest.push_back('[');
- spdlog::details::fmt_helper::append_string_view(msg.logger_name, dest);
- dest.push_back(']');
- dest.push_back(' ');
- }
-
- dest.push_back('[');
- // wrap the level name with color
- msg.color_range_start = dest.size();
- spdlog::details::fmt_helper::append_string_view(spdlog::level::to_string_view(msg.level), dest);
- msg.color_range_end = dest.size();
- dest.push_back(']');
- dest.push_back(' ');
-
- // add source location if present
- if (!msg.source.empty())
- {
- dest.push_back('[');
- const char* filename =
- spdlog::details::short_filename_formatter<spdlog::details::null_scoped_padder>::basename(msg.source.filename);
- spdlog::details::fmt_helper::append_string_view(filename, dest);
- dest.push_back(':');
- spdlog::details::fmt_helper::append_int(msg.source.line, dest);
- dest.push_back(']');
- dest.push_back(' ');
- }
-
- spdlog::details::fmt_helper::append_string_view(msg.payload, dest);
- spdlog::details::fmt_helper::append_string_view("\n"sv, dest);
- }
-
-private:
- std::chrono::time_point<std::chrono::system_clock> m_Epoch;
- std::tm m_CachedTm;
- std::chrono::seconds m_LastLogSecs{std::chrono::seconds(87654321)};
- std::chrono::seconds m_CacheTimestamp{std::chrono::seconds(87654321)};
- spdlog::memory_buf_t m_CachedDatetime;
- std::string m_LogId;
- RwLock m_TimestampLock;
-};
-
-} // namespace zen::logging
diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h
index d0402640b..2a8617162 100644
--- a/src/zenutil/include/zenutil/zenserverprocess.h
+++ b/src/zenutil/include/zenutil/zenserverprocess.h
@@ -42,9 +42,13 @@ public:
std::filesystem::path GetTestRootDir(std::string_view Path);
inline bool IsInitialized() const { return m_IsInitialized; }
inline bool IsTestEnvironment() const { return m_IsTestInstance; }
+ inline bool IsHubEnvironment() const { return m_IsHubInstance; }
inline std::string_view GetServerClass() const { return m_ServerClass; }
inline uint16_t GetNewPortNumber() { return m_NextPortNumber.fetch_add(1); }
+ void SetPassthroughOutput(bool Enable) { m_PassthroughOutput = Enable; }
+ bool IsPassthroughOutput() const { return m_PassthroughOutput; }
+
// The defaults will work for a single root process only. For hierarchical
// setups (e.g., hub managing storage servers), we need to be able to
// allocate distinct child IDs and ports to avoid overlap/conflicts.
@@ -54,9 +58,10 @@ public:
private:
std::filesystem::path m_ProgramBaseDir;
std::filesystem::path m_ChildProcessBaseDir;
- bool m_IsInitialized = false;
- bool m_IsTestInstance = false;
- bool m_IsHubInstance = false;
+ bool m_IsInitialized = false;
+ bool m_IsTestInstance = false;
+ bool m_IsHubInstance = false;
+ bool m_PassthroughOutput = false;
std::string m_ServerClass;
std::atomic_uint16_t m_NextPortNumber{20000};
};
@@ -79,6 +84,7 @@ struct ZenServerInstance
{
kStorageServer, // default
kHubServer,
+ kComputeServer,
};
ZenServerInstance(ZenServerEnvironment& TestEnvironment, ServerMode Mode = ServerMode::kStorageServer);
@@ -96,9 +102,12 @@ struct ZenServerInstance
inline int GetPid() const { return m_Process.Pid(); }
inline void SetOwnerPid(int Pid) { m_OwnerPid = Pid; }
void* GetProcessHandle() const { return m_Process.Handle(); }
- bool IsRunning();
- bool Terminate();
- std::string GetLogOutput() const;
+#if ZEN_PLATFORM_WINDOWS
+ void SetJobObject(JobObject* Job) { m_JobObject = Job; }
+#endif
+ bool IsRunning();
+ bool Terminate();
+ std::string GetLogOutput() const;
inline ServerMode GetServerMode() const { return m_ServerMode; }
@@ -147,6 +156,9 @@ private:
std::string m_Name;
std::filesystem::path m_OutputCapturePath;
std::filesystem::path m_ServerExecutablePath;
+#if ZEN_PLATFORM_WINDOWS
+ JobObject* m_JobObject = nullptr;
+#endif
void CreateShutdownEvent(int BasePort);
void SpawnServer(int BasePort, std::string_view AdditionalServerArgs, int WaitTimeoutMs);
diff --git a/src/zenutil/logging.cpp b/src/zenutil/logging.cpp
index 806b96d52..1258ca155 100644
--- a/src/zenutil/logging.cpp
+++ b/src/zenutil/logging.cpp
@@ -2,18 +2,15 @@
#include "zenutil/logging.h"
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <spdlog/async.h>
-#include <spdlog/async_logger.h>
-#include <spdlog/sinks/ansicolor_sink.h>
-#include <spdlog/sinks/msvc_sink.h>
-#include <spdlog/spdlog.h>
-ZEN_THIRD_PARTY_INCLUDES_END
-
#include <zencore/callstack.h>
#include <zencore/compactbinary.h>
#include <zencore/filesystem.h>
#include <zencore/logging.h>
+#include <zencore/logging/ansicolorsink.h>
+#include <zencore/logging/asyncsink.h>
+#include <zencore/logging/logger.h>
+#include <zencore/logging/msvcsink.h>
+#include <zencore/logging/registry.h>
#include <zencore/memory/llm.h>
#include <zencore/string.h>
#include <zencore/timer.h>
@@ -27,9 +24,9 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
static bool g_IsLoggingInitialized;
-spdlog::sink_ptr g_FileSink;
+logging::SinkPtr g_FileSink;
-spdlog::sink_ptr
+logging::SinkPtr
GetFileSink()
{
return g_FileSink;
@@ -52,33 +49,9 @@ BeginInitializeLogging(const LoggingOptions& LogOptions)
zen::logging::InitializeLogging();
zen::logging::EnableVTMode();
- bool IsAsync = LogOptions.AllowAsync;
-
- if (LogOptions.IsDebug)
- {
- IsAsync = false;
- }
-
- if (LogOptions.IsTest)
- {
- IsAsync = false;
- }
-
- if (IsAsync)
- {
- const int QueueSize = 8192;
- const int ThreadCount = 1;
- spdlog::init_thread_pool(QueueSize, ThreadCount, [&] { SetCurrentThreadName("spdlog_async"); });
-
- auto AsyncSink = spdlog::create_async<spdlog::sinks::ansicolor_stdout_sink_mt>("main");
- zen::logging::SetDefault("main");
- }
-
// Sinks
- spdlog::sink_ptr FileSink;
-
- // spdlog can't create directories that starts with `\\?\` so we make sure the folder exists before creating the logger instance
+ logging::SinkPtr FileSink;
if (!LogOptions.AbsLogFile.empty())
{
@@ -87,17 +60,17 @@ BeginInitializeLogging(const LoggingOptions& LogOptions)
zen::CreateDirectories(LogOptions.AbsLogFile.parent_path());
}
- FileSink = std::make_shared<zen::logging::RotatingFileSink>(LogOptions.AbsLogFile,
- /* max size */ 128 * 1024 * 1024,
- /* max files */ 16,
- /* rotate on open */ true);
+ FileSink = logging::SinkPtr(new zen::logging::RotatingFileSink(LogOptions.AbsLogFile,
+ /* max size */ 128 * 1024 * 1024,
+ /* max files */ 16,
+ /* rotate on open */ true));
if (LogOptions.AbsLogFile.extension() == ".json")
{
- FileSink->set_formatter(std::make_unique<logging::json_formatter>(LogOptions.LogId));
+ FileSink->SetFormatter(std::make_unique<logging::JsonFormatter>(LogOptions.LogId));
}
else
{
- FileSink->set_formatter(std::make_unique<logging::full_formatter>(LogOptions.LogId)); // this will have a date prefix
+ FileSink->SetFormatter(std::make_unique<logging::FullFormatter>(LogOptions.LogId)); // this will have a date prefix
}
}
@@ -127,7 +100,7 @@ BeginInitializeLogging(const LoggingOptions& LogOptions)
Message.push_back('\0');
// We use direct ZEN_LOG here instead of ZEN_ERROR as we don't care about *this* code location in the log
- ZEN_LOG(Log(), zen::logging::level::Critical, "{}", Message.data());
+ ZEN_LOG(Log(), zen::logging::Critical, "{}", Message.data());
zen::logging::FlushLogging();
}
catch (const std::exception&)
@@ -143,9 +116,9 @@ BeginInitializeLogging(const LoggingOptions& LogOptions)
// Default
LoggerRef DefaultLogger = zen::logging::Default();
- auto& Sinks = DefaultLogger.SpdLogger->sinks();
- Sinks.clear();
+ // Collect sinks into a local vector first so we can optionally wrap them
+ std::vector<logging::SinkPtr> Sinks;
if (LogOptions.NoConsoleOutput)
{
@@ -153,10 +126,10 @@ BeginInitializeLogging(const LoggingOptions& LogOptions)
}
else
{
- auto ConsoleSink = std::make_shared<spdlog::sinks::ansicolor_stdout_sink_mt>();
+ logging::SinkPtr ConsoleSink(new logging::AnsiColorStdoutSink());
if (LogOptions.QuietConsole)
{
- ConsoleSink->set_level(spdlog::level::warn);
+ ConsoleSink->SetLevel(logging::Warn);
}
Sinks.push_back(ConsoleSink);
}
@@ -169,40 +142,54 @@ BeginInitializeLogging(const LoggingOptions& LogOptions)
#if ZEN_PLATFORM_WINDOWS
if (zen::IsDebuggerPresent() && LogOptions.IsDebug)
{
- auto DebugSink = std::make_shared<spdlog::sinks::msvc_sink_mt>();
- DebugSink->set_level(spdlog::level::debug);
+ logging::SinkPtr DebugSink(new logging::MsvcSink());
+ DebugSink->SetLevel(logging::Debug);
Sinks.push_back(DebugSink);
}
#endif
- spdlog::set_error_handler([](const std::string& msg) {
- if (msg == std::bad_alloc().what())
- {
- // Don't report out of memory in spdlog as we usually log in response to errors which will cause another OOM crashing the
- // program
- return;
- }
- // Bypass zen logging wrapping to reduce potential other error sources
- if (auto ErrLogger = zen::logging::ErrorLog())
+ bool IsAsync = LogOptions.AllowAsync && !LogOptions.IsDebug && !LogOptions.IsTest;
+
+ if (IsAsync)
+ {
+ std::vector<logging::SinkPtr> AsyncSinks;
+ AsyncSinks.emplace_back(new logging::AsyncSink(std::move(Sinks)));
+ DefaultLogger->SetSinks(std::move(AsyncSinks));
+ }
+ else
+ {
+ DefaultLogger->SetSinks(std::move(Sinks));
+ }
+
+ static struct : logging::ErrorHandler
+ {
+ void HandleError(const std::string_view& ErrorMsg) override
{
+ if (ErrorMsg == std::bad_alloc().what())
+ {
+ return;
+ }
+ static constinit logging::LogPoint ErrorPoint{{}, logging::Err, "{}"};
+ if (auto ErrLogger = zen::logging::ErrorLog())
+ {
+ try
+ {
+ ErrLogger->Log(ErrorPoint, fmt::make_format_args(ErrorMsg));
+ }
+ catch (const std::exception&)
+ {
+ }
+ }
try
{
- ErrLogger.SpdLogger->log(spdlog::level::err, msg);
+ Log()->Log(ErrorPoint, fmt::make_format_args(ErrorMsg));
}
catch (const std::exception&)
{
- // Just ignore any errors when in error handler
}
}
- try
- {
- Log().SpdLogger->error(msg);
- }
- catch (const std::exception&)
- {
- // Just ignore any errors when in error handler
- }
- });
+ } s_ErrorHandler;
+ logging::Registry::Instance().SetErrorHandler(&s_ErrorHandler);
g_FileSink = std::move(FileSink);
}
@@ -212,41 +199,47 @@ FinishInitializeLogging(const LoggingOptions& LogOptions)
{
ZEN_MEMSCOPE(ELLMTag::Logging);
- logging::level::LogLevel LogLevel = logging::level::Info;
+ logging::LogLevel LogLevel = logging::Info;
if (LogOptions.IsDebug)
{
- LogLevel = logging::level::Debug;
+ LogLevel = logging::Debug;
}
if (LogOptions.IsTest || LogOptions.IsVerbose)
{
- LogLevel = logging::level::Trace;
+ LogLevel = logging::Trace;
}
// Configure all registered loggers according to settings
logging::RefreshLogLevels(LogLevel);
- spdlog::flush_on(spdlog::level::err);
- spdlog::flush_every(std::chrono::seconds{2});
- spdlog::set_formatter(std::make_unique<logging::full_formatter>(
+ logging::Registry::Instance().FlushOn(logging::Err);
+ logging::Registry::Instance().FlushEvery(std::chrono::seconds{2});
+ logging::Registry::Instance().SetFormatter(std::make_unique<logging::FullFormatter>(
LogOptions.LogId,
std::chrono::system_clock::now() - std::chrono::milliseconds(GetTimeSinceProcessStart()))); // default to duration prefix
+ // If the console logger was initialized before, the above will change the output format
+ // so we need to reset it
+
+ logging::ResetConsoleLog();
+
if (g_FileSink)
{
if (LogOptions.AbsLogFile.extension() == ".json")
{
- g_FileSink->set_formatter(std::make_unique<logging::json_formatter>(LogOptions.LogId));
+ g_FileSink->SetFormatter(std::make_unique<logging::JsonFormatter>(LogOptions.LogId));
}
else
{
- g_FileSink->set_formatter(std::make_unique<logging::full_formatter>(LogOptions.LogId)); // this will have a date prefix
+ g_FileSink->SetFormatter(std::make_unique<logging::FullFormatter>(LogOptions.LogId)); // this will have a date prefix
}
const std::string StartLogTime = zen::DateTime::Now().ToIso8601();
- spdlog::apply_all([&](auto Logger) { Logger->info("log starting at {}", StartLogTime); });
+ static constinit logging::LogPoint LogStartPoint{{}, logging::Info, "log starting at {}"};
+ logging::Registry::Instance().ApplyAll([&](auto Logger) { Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime)); });
}
g_IsLoggingInitialized = true;
@@ -263,7 +256,7 @@ ShutdownLogging()
zen::logging::ShutdownLogging();
- g_FileSink.reset();
+ g_FileSink = nullptr;
}
} // namespace zen
diff --git a/src/zenutil/rpcrecording.cpp b/src/zenutil/rpcrecording.cpp
index 54f27dee7..28a0091cb 100644
--- a/src/zenutil/rpcrecording.cpp
+++ b/src/zenutil/rpcrecording.cpp
@@ -1119,7 +1119,7 @@ rpcrecord_forcelink()
{
}
-TEST_SUITE_BEGIN("rpc.recording");
+TEST_SUITE_BEGIN("util.rpcrecording");
TEST_CASE("rpc.record")
{
diff --git a/src/zenutil/wildcard.cpp b/src/zenutil/wildcard.cpp
index 7a44c0498..7f2f77780 100644
--- a/src/zenutil/wildcard.cpp
+++ b/src/zenutil/wildcard.cpp
@@ -118,6 +118,8 @@ wildcard_forcelink()
{
}
+TEST_SUITE_BEGIN("util.wildcard");
+
TEST_CASE("Wildcard")
{
CHECK(MatchWildcard("*.*", "normal.txt", true));
@@ -151,5 +153,7 @@ TEST_CASE("Wildcard")
CHECK(MatchWildcard("*.d", "dir/path.d", true));
}
+TEST_SUITE_END();
+
#endif
} // namespace zen
diff --git a/src/zenutil/xmake.lua b/src/zenutil/xmake.lua
index bc33adf9e..1d5be5977 100644
--- a/src/zenutil/xmake.lua
+++ b/src/zenutil/xmake.lua
@@ -6,7 +6,7 @@ target('zenutil')
add_headerfiles("**.h")
add_files("**.cpp")
add_includedirs("include", {public=true})
- add_deps("zencore", "zenhttp", "spdlog")
+ add_deps("zencore", "zenhttp")
add_deps("cxxopts")
add_deps("robin-map")
diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp
index ef2a4fda5..b09c2d89a 100644
--- a/src/zenutil/zenserverprocess.cpp
+++ b/src/zenutil/zenserverprocess.cpp
@@ -787,6 +787,8 @@ ToString(ZenServerInstance::ServerMode Mode)
return "storage"sv;
case ZenServerInstance::ServerMode::kHubServer:
return "hub"sv;
+ case ZenServerInstance::ServerMode::kComputeServer:
+ return "compute"sv;
default:
return "invalid"sv;
}
@@ -808,6 +810,10 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs,
{
CommandLine << " hub";
}
+ else if (m_ServerMode == ServerMode::kComputeServer)
+ {
+ CommandLine << " compute";
+ }
CommandLine << " --child-id " << ChildEventName;
@@ -829,10 +835,18 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs,
const std::filesystem::path BaseDir = m_Env.ProgramBaseDir();
const std::filesystem::path Executable =
m_ServerExecutablePath.empty() ? (BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL) : m_ServerExecutablePath;
- const std::filesystem::path OutputPath =
- OpenConsole ? std::filesystem::path{} : std::filesystem::temp_directory_path() / ("zenserver_" + m_Name + ".log");
- CreateProcOptions CreateOptions = {.WorkingDirectory = &CurrentDirectory, .Flags = CreationFlags, .StdoutFile = OutputPath};
- CreateProcResult ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions);
+ const std::filesystem::path OutputPath = (OpenConsole || m_Env.IsPassthroughOutput())
+ ? std::filesystem::path{}
+ : std::filesystem::temp_directory_path() / ("zenserver_" + m_Name + ".log");
+ CreateProcOptions CreateOptions = {
+ .WorkingDirectory = &CurrentDirectory,
+ .Flags = CreationFlags,
+ .StdoutFile = OutputPath,
+#if ZEN_PLATFORM_WINDOWS
+ .AssignToJob = m_JobObject,
+#endif
+ };
+ CreateProcResult ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions);
#if ZEN_PLATFORM_WINDOWS
if (!ChildPid)
{
@@ -841,6 +855,12 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs,
{
ZEN_DEBUG("Regular spawn failed - spawning elevated server");
CreateOptions.Flags |= CreateProcOptions::Flag_Elevated;
+ // ShellExecuteEx (used by the elevated path) does not support job object assignment
+ if (CreateOptions.AssignToJob)
+ {
+ ZEN_WARN("Elevated process spawn does not support job object assignment; child will not be auto-terminated on parent exit");
+ CreateOptions.AssignToJob = nullptr;
+ }
ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions);
}
else
@@ -934,7 +954,8 @@ ZenServerInstance::SpawnServer(int BasePort, std::string_view AdditionalServerAr
CommandLine << " " << AdditionalServerArgs;
}
- SpawnServerInternal(ChildId, CommandLine, !IsTest, WaitTimeoutMs);
+ const bool OpenConsole = !IsTest && !m_Env.IsHubEnvironment();
+ SpawnServerInternal(ChildId, CommandLine, OpenConsole, WaitTimeoutMs);
}
void
diff --git a/src/zenutil/zenutil.cpp b/src/zenutil/zenutil.cpp
index 51c1ee72e..291dbeadd 100644
--- a/src/zenutil/zenutil.cpp
+++ b/src/zenutil/zenutil.cpp
@@ -5,7 +5,7 @@
#if ZEN_WITH_TESTS
# include <zenutil/rpcrecording.h>
-# include <zenutil/commandlineoptions.h>
+# include <zenutil/config/commandlineoptions.h>
# include <zenutil/wildcard.h>
namespace zen {
diff --git a/src/zenvfs/xmake.lua b/src/zenvfs/xmake.lua
index 7f790c2d4..47665a5d5 100644
--- a/src/zenvfs/xmake.lua
+++ b/src/zenvfs/xmake.lua
@@ -6,5 +6,5 @@ target('zenvfs')
add_headerfiles("**.h")
add_files("**.cpp")
add_includedirs("include", {public=true})
- add_deps("zencore", "spdlog")
+ add_deps("zencore")