aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/zen/cmds/builds_cmd.cpp365
-rw-r--r--src/zen/cmds/builds_cmd.h3
-rw-r--r--src/zen/cmds/compute_cmd.cpp96
-rw-r--r--src/zen/cmds/compute_cmd.h53
-rw-r--r--src/zen/cmds/exec_cmd.cpp92
-rw-r--r--src/zen/cmds/exec_cmd.h1
-rw-r--r--src/zen/cmds/projectstore_cmd.cpp15
-rw-r--r--src/zen/cmds/projectstore_cmd.h2
-rw-r--r--src/zen/cmds/service_cmd.cpp2
-rw-r--r--src/zen/cmds/ui_cmd.cpp2
-rw-r--r--src/zen/cmds/wipe_cmd.cpp3
-rw-r--r--src/zen/cmds/workspaces_cmd.cpp4
-rw-r--r--src/zen/progressbar.cpp279
-rw-r--r--src/zen/progressbar.h51
-rw-r--r--src/zen/zen.cpp7
-rw-r--r--src/zenbase/include/zenbase/atomic.h74
-rw-r--r--src/zenbase/include/zenbase/refcount.h157
-rw-r--r--src/zencompute/CLAUDE.md17
-rw-r--r--src/zencompute/cloudmetadata.cpp8
-rw-r--r--src/zencompute/computeservice.cpp449
-rw-r--r--src/zencompute/httpcomputeservice.cpp274
-rw-r--r--src/zencompute/httporchestrator.cpp142
-rw-r--r--src/zencompute/include/zencompute/cloudmetadata.h2
-rw-r--r--src/zencompute/include/zencompute/computeservice.h8
-rw-r--r--src/zencompute/include/zencompute/httpcomputeservice.h11
-rw-r--r--src/zencompute/include/zencompute/httporchestrator.h23
-rw-r--r--src/zencompute/include/zencompute/mockimds.h2
-rw-r--r--src/zencompute/include/zencompute/orchestratorservice.h19
-rw-r--r--src/zencompute/include/zencompute/provisionerstate.h38
-rw-r--r--src/zencompute/orchestratorservice.cpp33
-rw-r--r--src/zencompute/pathvalidation.h118
-rw-r--r--src/zencompute/runners/functionrunner.cpp132
-rw-r--r--src/zencompute/runners/functionrunner.h27
-rw-r--r--src/zencompute/runners/linuxrunner.cpp22
-rw-r--r--src/zencompute/runners/localrunner.cpp40
-rw-r--r--src/zencompute/runners/localrunner.h4
-rw-r--r--src/zencompute/runners/macrunner.cpp18
-rw-r--r--src/zencompute/runners/managedrunner.cpp279
-rw-r--r--src/zencompute/runners/managedrunner.h64
-rw-r--r--src/zencompute/runners/remotehttprunner.cpp366
-rw-r--r--src/zencompute/runners/remotehttprunner.h14
-rw-r--r--src/zencompute/runners/windowsrunner.cpp105
-rw-r--r--src/zencompute/runners/windowsrunner.h1
-rw-r--r--src/zencompute/runners/winerunner.cpp6
-rw-r--r--src/zencore/compactbinaryjson.cpp34
-rw-r--r--src/zencore/compactbinarypackage.cpp122
-rw-r--r--src/zencore/crashhandler.cpp2
-rw-r--r--src/zencore/filesystem.cpp163
-rw-r--r--src/zencore/include/zencore/compactbinarypackage.h19
-rw-r--r--src/zencore/include/zencore/filesystem.h4
-rw-r--r--src/zencore/include/zencore/fmtutils.h65
-rw-r--r--src/zencore/include/zencore/hashutils.h19
-rw-r--r--src/zencore/include/zencore/iobuffer.h19
-rw-r--r--src/zencore/include/zencore/iohash.h1
-rw-r--r--src/zencore/include/zencore/logbase.h13
-rw-r--r--src/zencore/include/zencore/logging.h93
-rw-r--r--src/zencore/include/zencore/logging/broadcastsink.h4
-rw-r--r--src/zencore/include/zencore/logging/helpers.h2
-rw-r--r--src/zencore/include/zencore/logging/logmsg.h49
-rw-r--r--src/zencore/include/zencore/mpscqueue.h2
-rw-r--r--src/zencore/include/zencore/process.h23
-rw-r--r--src/zencore/include/zencore/sharedbuffer.h8
-rw-r--r--src/zencore/include/zencore/string.h61
-rw-r--r--src/zencore/include/zencore/system.h5
-rw-r--r--src/zencore/include/zencore/testutils.h22
-rw-r--r--src/zencore/include/zencore/thread.h2
-rw-r--r--src/zencore/include/zencore/zencore.h2
-rw-r--r--src/zencore/iobuffer.cpp13
-rw-r--r--src/zencore/jobqueue.cpp24
-rw-r--r--src/zencore/logging.cpp31
-rw-r--r--src/zencore/logging/ansicolorsink.cpp68
-rw-r--r--src/zencore/logging/registry.cpp4
-rw-r--r--src/zencore/memory/memory.cpp4
-rw-r--r--src/zencore/process.cpp211
-rw-r--r--src/zencore/refcount.cpp24
-rw-r--r--src/zencore/sentryintegration.cpp2
-rw-r--r--src/zencore/sharedbuffer.cpp2
-rw-r--r--src/zencore/string.cpp62
-rw-r--r--src/zencore/system.cpp97
-rw-r--r--src/zencore/testing.cpp2
-rw-r--r--src/zencore/testutils.cpp12
-rw-r--r--src/zencore/thread.cpp4
-rw-r--r--src/zencore/trace.cpp3
-rw-r--r--src/zencore/zencore.cpp2
-rw-r--r--src/zenhorde/README.md17
-rw-r--r--src/zenhorde/hordeagent.cpp561
-rw-r--r--src/zenhorde/hordeagent.h127
-rw-r--r--src/zenhorde/hordeagentmessage.cpp581
-rw-r--r--src/zenhorde/hordeagentmessage.h153
-rw-r--r--src/zenhorde/hordebundle.cpp63
-rw-r--r--src/zenhorde/hordeclient.cpp86
-rw-r--r--src/zenhorde/hordecomputebuffer.cpp454
-rw-r--r--src/zenhorde/hordecomputebuffer.h136
-rw-r--r--src/zenhorde/hordecomputechannel.cpp37
-rw-r--r--src/zenhorde/hordecomputechannel.h32
-rw-r--r--src/zenhorde/hordecomputesocket.cpp410
-rw-r--r--src/zenhorde/hordecomputesocket.h109
-rw-r--r--src/zenhorde/hordeconfig.cpp16
-rw-r--r--src/zenhorde/hordeprovisioner.cpp682
-rw-r--r--src/zenhorde/hordetransport.cpp153
-rw-r--r--src/zenhorde/hordetransport.h67
-rw-r--r--src/zenhorde/hordetransportaes.cpp718
-rw-r--r--src/zenhorde/hordetransportaes.h51
-rw-r--r--src/zenhorde/include/zenhorde/hordeclient.h32
-rw-r--r--src/zenhorde/include/zenhorde/hordeconfig.h37
-rw-r--r--src/zenhorde/include/zenhorde/hordeprovisioner.h92
-rw-r--r--src/zenhttp/asynchttpclient_test.cpp315
-rw-r--r--src/zenhttp/auth/authmgr.cpp12
-rw-r--r--src/zenhttp/clients/asynchttpclient.cpp1033
-rw-r--r--src/zenhttp/clients/httpclientcurl.cpp303
-rw-r--r--src/zenhttp/clients/httpclientcurl.h1
-rw-r--r--src/zenhttp/clients/httpclientcurlhelpers.h298
-rw-r--r--src/zenhttp/httpclient.cpp2
-rw-r--r--src/zenhttp/httpclient_test.cpp41
-rw-r--r--src/zenhttp/httpclientauth.cpp20
-rw-r--r--src/zenhttp/httpserver.cpp125
-rw-r--r--src/zenhttp/include/zenhttp/asynchttpclient.h123
-rw-r--r--src/zenhttp/include/zenhttp/httpclientauth.h3
-rw-r--r--src/zenhttp/include/zenhttp/httpcommon.h14
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h29
-rw-r--r--src/zenhttp/include/zenhttp/httpstats.h6
-rw-r--r--src/zenhttp/include/zenhttp/httpwsclient.h8
-rw-r--r--src/zenhttp/include/zenhttp/localrefpolicy.h21
-rw-r--r--src/zenhttp/include/zenhttp/packageformat.h25
-rw-r--r--src/zenhttp/include/zenhttp/websocket.h2
-rw-r--r--src/zenhttp/monitoring/httpstats.cpp3
-rw-r--r--src/zenhttp/packageformat.cpp548
-rw-r--r--src/zenhttp/servers/httpasio.cpp38
-rw-r--r--src/zenhttp/servers/httpparser.cpp414
-rw-r--r--src/zenhttp/servers/httpparser.h8
-rw-r--r--src/zenhttp/servers/httpplugin.cpp32
-rw-r--r--src/zenhttp/servers/httpsys.cpp45
-rw-r--r--src/zenhttp/servers/wsasio.cpp2
-rw-r--r--src/zenhttp/servers/wshttpsys.cpp6
-rw-r--r--src/zenhttp/servers/wstest.cpp37
-rw-r--r--src/zenhttp/xmake.lua2
-rw-r--r--src/zenhttp/zenhttp.cpp3
-rw-r--r--src/zennomad/include/zennomad/nomadclient.h6
-rw-r--r--src/zennomad/include/zennomad/nomadprovisioner.h9
-rw-r--r--src/zennomad/nomadclient.cpp38
-rw-r--r--src/zennomad/nomadprocess.cpp2
-rw-r--r--src/zennomad/nomadprovisioner.cpp11
-rw-r--r--src/zenremotestore/builds/buildstoragecache.cpp63
-rw-r--r--src/zenremotestore/builds/buildstorageoperations.cpp6803
-rw-r--r--src/zenremotestore/builds/buildstorageutil.cpp50
-rw-r--r--src/zenremotestore/builds/jupiterbuildstorage.cpp51
-rw-r--r--src/zenremotestore/chunking/chunkblock.cpp274
-rw-r--r--src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h412
-rw-r--r--src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h5
-rw-r--r--src/zenremotestore/include/zenremotestore/chunking/chunkblock.h11
-rw-r--r--src/zenremotestore/include/zenremotestore/operationlogoutput.h76
-rw-r--r--src/zenremotestore/include/zenremotestore/projectstore/projectstoreoperations.h35
-rw-r--r--src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h7
-rw-r--r--src/zenremotestore/jupiter/jupitersession.cpp2
-rw-r--r--src/zenremotestore/operationlogoutput.cpp103
-rw-r--r--src/zenremotestore/projectstore/buildsremoteprojectstore.cpp5
-rw-r--r--src/zenremotestore/projectstore/projectstoreoperations.cpp244
-rw-r--r--src/zenremotestore/projectstore/remoteprojectstore.cpp441
-rw-r--r--src/zenremotestore/zenremotestore.cpp2
-rw-r--r--src/zens3-testbed/main.cpp526
-rw-r--r--src/zens3-testbed/xmake.lua8
-rw-r--r--src/zenserver-test/buildstore-tests.cpp155
-rw-r--r--src/zenserver-test/cache-tests.cpp351
-rw-r--r--src/zenserver-test/compute-tests.cpp642
-rw-r--r--src/zenserver-test/hub-tests.cpp46
-rw-r--r--src/zenserver-test/logging-tests.cpp26
-rw-r--r--src/zenserver-test/objectstore-tests.cpp320
-rw-r--r--src/zenserver-test/process-tests.cpp4
-rw-r--r--src/zenserver-test/projectstore-tests.cpp1123
-rw-r--r--src/zenserver-test/xmake.lua2
-rw-r--r--src/zenserver-test/zenserver-test.cpp10
-rw-r--r--src/zenserver/compute/computeserver.cpp132
-rw-r--r--src/zenserver/compute/computeserver.h16
-rw-r--r--src/zenserver/config/config.cpp23
-rw-r--r--src/zenserver/diag/logging.cpp2
-rw-r--r--src/zenserver/frontend/frontend.cpp2
-rw-r--r--src/zenserver/frontend/html/compute/compute.html925
-rw-r--r--src/zenserver/frontend/html/compute/hub.html2
-rw-r--r--src/zenserver/frontend/html/compute/index.html2
-rw-r--r--src/zenserver/frontend/html/compute/orchestrator.html669
-rw-r--r--src/zenserver/frontend/html/pages/builds.js56
-rw-r--r--src/zenserver/frontend/html/pages/cache.js319
-rw-r--r--src/zenserver/frontend/html/pages/compute.js80
-rw-r--r--src/zenserver/frontend/html/pages/entry.js2
-rw-r--r--src/zenserver/frontend/html/pages/hub.js375
-rw-r--r--src/zenserver/frontend/html/pages/orchestrator.js248
-rw-r--r--src/zenserver/frontend/html/pages/page.js68
-rw-r--r--src/zenserver/frontend/html/pages/projects.js236
-rw-r--r--src/zenserver/frontend/html/pages/start.js166
-rw-r--r--src/zenserver/frontend/html/pages/workspaces.js5
-rw-r--r--src/zenserver/frontend/html/util/widgets.js181
-rw-r--r--src/zenserver/frontend/html/zen.css99
-rw-r--r--src/zenserver/frontend/zipfs.cpp4
-rw-r--r--src/zenserver/hub/README.md17
-rw-r--r--src/zenserver/hub/httphubservice.cpp214
-rw-r--r--src/zenserver/hub/httphubservice.h17
-rw-r--r--src/zenserver/hub/httpproxyhandler.cpp528
-rw-r--r--src/zenserver/hub/httpproxyhandler.h52
-rw-r--r--src/zenserver/hub/hub.cpp1010
-rw-r--r--src/zenserver/hub/hub.h83
-rw-r--r--src/zenserver/hub/hubinstancestate.cpp2
-rw-r--r--src/zenserver/hub/hubinstancestate.h3
-rw-r--r--src/zenserver/hub/hydration.cpp2038
-rw-r--r--src/zenserver/hub/hydration.h36
-rw-r--r--src/zenserver/hub/storageserverinstance.cpp139
-rw-r--r--src/zenserver/hub/storageserverinstance.h56
-rw-r--r--src/zenserver/hub/zenhubserver.cpp399
-rw-r--r--src/zenserver/hub/zenhubserver.h50
-rw-r--r--src/zenserver/main.cpp10
-rw-r--r--src/zenserver/proxy/httptrafficinspector.cpp73
-rw-r--r--src/zenserver/proxy/httptrafficinspector.h10
-rw-r--r--src/zenserver/proxy/zenproxyserver.cpp17
-rw-r--r--src/zenserver/sessions/httpsessions.cpp5
-rw-r--r--src/zenserver/sessions/httpsessions.h2
-rw-r--r--src/zenserver/sessions/sessions.cpp2
-rw-r--r--src/zenserver/storage/buildstore/httpbuildstore.cpp144
-rw-r--r--src/zenserver/storage/cache/httpstructuredcache.cpp16
-rw-r--r--src/zenserver/storage/cache/httpstructuredcache.h10
-rw-r--r--src/zenserver/storage/localrefpolicy.cpp29
-rw-r--r--src/zenserver/storage/localrefpolicy.h25
-rw-r--r--src/zenserver/storage/objectstore/objectstore.cpp63
-rw-r--r--src/zenserver/storage/projectstore/httpprojectstore.cpp49
-rw-r--r--src/zenserver/storage/projectstore/httpprojectstore.h10
-rw-r--r--src/zenserver/storage/storageconfig.cpp18
-rw-r--r--src/zenserver/storage/storageconfig.h5
-rw-r--r--src/zenserver/storage/upstream/upstreamcache.cpp22
-rw-r--r--src/zenserver/storage/zenstorageserver.cpp70
-rw-r--r--src/zenserver/storage/zenstorageserver.h26
-rw-r--r--src/zenserver/xmake.lua2
-rw-r--r--src/zenserver/zenserver.cpp58
-rw-r--r--src/zenserver/zenserver.h10
-rw-r--r--src/zenstore/buildstore/buildstore.cpp6
-rw-r--r--src/zenstore/cache/cachedisklayer.cpp14
-rw-r--r--src/zenstore/cache/structuredcachestore.cpp12
-rw-r--r--src/zenstore/cidstore.cpp12
-rw-r--r--src/zenstore/compactcas.cpp19
-rw-r--r--src/zenstore/filecas.cpp6
-rw-r--r--src/zenstore/gc.cpp10
-rw-r--r--src/zenstore/include/zenstore/cache/cachepolicy.h10
-rw-r--r--src/zenstore/include/zenstore/cidstore.h22
-rw-r--r--src/zenstore/include/zenstore/memorycidstore.h68
-rw-r--r--src/zenstore/include/zenstore/projectstore.h13
-rw-r--r--src/zenstore/include/zenstore/zenstore.h37
-rw-r--r--src/zenstore/memorycidstore.cpp143
-rw-r--r--src/zenstore/projectstore.cpp87
-rw-r--r--src/zenstore/workspaces.cpp23
-rw-r--r--src/zenstore/zenstore.cpp20
-rw-r--r--src/zentelemetry/include/zentelemetry/stats.h10
-rw-r--r--src/zentelemetry/otelmetricsprotozero.h32
-rw-r--r--src/zentest-appstub/zentest-appstub.cpp55
-rw-r--r--src/zenutil/cloud/imdscredentials.cpp6
-rw-r--r--src/zenutil/cloud/minioprocess.cpp4
-rw-r--r--src/zenutil/cloud/mockimds.cpp6
-rw-r--r--src/zenutil/cloud/s3client.cpp118
-rw-r--r--src/zenutil/consoletui.cpp4
-rw-r--r--src/zenutil/consul/consul.cpp305
-rw-r--r--src/zenutil/filesystemutils.cpp (renamed from src/zenremotestore/filesystemutils.cpp)252
-rw-r--r--src/zenutil/include/zenutil/cloud/mockimds.h14
-rw-r--r--src/zenutil/include/zenutil/cloud/s3client.h9
-rw-r--r--src/zenutil/include/zenutil/consoletui.h2
-rw-r--r--src/zenutil/include/zenutil/consul.h17
-rw-r--r--src/zenutil/include/zenutil/filesystemutils.h (renamed from src/zenremotestore/include/zenremotestore/filesystemutils.h)41
-rw-r--r--src/zenutil/include/zenutil/process/subprocessmanager.h26
-rw-r--r--src/zenutil/include/zenutil/progress.h65
-rw-r--r--src/zenutil/include/zenutil/sessionsclient.h6
-rw-r--r--src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h8
-rw-r--r--src/zenutil/include/zenutil/zenserverprocess.h1
-rw-r--r--src/zenutil/logging/fullformatter.cpp30
-rw-r--r--src/zenutil/logging/jsonformatter.cpp2
-rw-r--r--src/zenutil/logging/logging.cpp6
-rw-r--r--src/zenutil/logging/rotatingfilesink.cpp2
-rw-r--r--src/zenutil/process/asyncpipereader.cpp6
-rw-r--r--src/zenutil/process/subprocessmanager.cpp262
-rw-r--r--src/zenutil/progress.cpp99
-rw-r--r--src/zenutil/sessionsclient.cpp11
-rw-r--r--src/zenutil/splitconsole/logstreamlistener.cpp14
-rw-r--r--src/zenutil/zenserverprocess.cpp45
-rw-r--r--src/zenutil/zenutil.cpp4
278 files changed, 22531 insertions, 13088 deletions
diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp
index 93108dd47..3373506f2 100644
--- a/src/zen/cmds/builds_cmd.cpp
+++ b/src/zen/cmds/builds_cmd.cpp
@@ -36,10 +36,10 @@
#include <zenremotestore/chunking/chunkedfile.h>
#include <zenremotestore/chunking/chunkingcache.h>
#include <zenremotestore/chunking/chunkingcontroller.h>
-#include <zenremotestore/filesystemutils.h>
#include <zenremotestore/jupiter/jupiterhost.h>
-#include <zenremotestore/operationlogoutput.h>
#include <zenremotestore/transferthreadworkers.h>
+#include <zenutil/filesystemutils.h>
+#include <zenutil/progress.h>
#include <zenutil/wildcard.h>
#include <zenutil/workerpools.h>
#include <zenutil/zenserverprocess.h>
@@ -387,7 +387,8 @@ namespace builds_impl {
return CleanAndRemoveDirectory(WorkerPool, AbortFlag, PauseFlag, Directory);
}
- void ValidateBuildPart(OperationLogOutput& Output,
+ void ValidateBuildPart(LoggerRef Log,
+ ProgressBase& Progress,
TransferThreadWorkers& Workers,
BuildStorageBase& Storage,
const Oid& BuildId,
@@ -398,7 +399,8 @@ namespace builds_impl {
ProgressBar::SetLogOperationName(ProgressMode, "Validate Part");
- BuildsOperationValidateBuildPart ValidateOp(Output,
+ BuildsOperationValidateBuildPart ValidateOp(Log,
+ Progress,
Storage,
AbortFlag,
PauseFlag,
@@ -434,7 +436,8 @@ namespace builds_impl {
const std::vector<std::string>& ExcludeExtensions = DefaultExcludeExtensions;
};
- std::vector<std::pair<Oid, std::string>> UploadFolder(OperationLogOutput& Output,
+ std::vector<std::pair<Oid, std::string>> UploadFolder(LoggerRef Log,
+ ProgressBase& Progress,
TransferThreadWorkers& Workers,
StorageInstance& Storage,
const Oid& BuildId,
@@ -452,7 +455,8 @@ namespace builds_impl {
Stopwatch UploadTimer;
BuildsOperationUploadFolder UploadOp(
- Output,
+ Log,
+ Progress,
Storage,
AbortFlag,
PauseFlag,
@@ -1232,7 +1236,6 @@ namespace builds_impl {
EPartialBlockRequestMode PartialBlockRequestMode = EPartialBlockRequestMode::Mixed;
bool CleanTargetFolder = false;
bool PostDownloadVerify = false;
- bool PrimeCacheOnly = false;
bool EnableOtherDownloadsScavenging = true;
bool EnableTargetFolderScavenging = true;
bool AllowFileClone = true;
@@ -1244,7 +1247,8 @@ namespace builds_impl {
std::vector<std::string> ExcludeFolders = DefaultExcludeFolders;
};
- void DownloadFolder(OperationLogOutput& Output,
+ void DownloadFolder(LoggerRef InLog,
+ ProgressBase& Progress,
TransferThreadWorkers& Workers,
StorageInstance& Storage,
const BuildStorageCache::Statistics& StorageCacheStats,
@@ -1256,6 +1260,7 @@ namespace builds_impl {
const DownloadOptions& Options)
{
ZEN_TRACE_CPU("DownloadFolder");
+ ZEN_SCOPED_LOG(InLog);
ProgressBar::SetLogOperationName(ProgressMode, "Download Folder");
@@ -1272,9 +1277,6 @@ namespace builds_impl {
auto EndProgress =
MakeGuard([&]() { ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::StepCount, TaskSteps::StepCount); });
- ZEN_ASSERT((!Options.PrimeCacheOnly) ||
- (Options.PrimeCacheOnly && (Options.PartialBlockRequestMode == EPartialBlockRequestMode::Off)));
-
Stopwatch DownloadTimer;
ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::CheckState, TaskSteps::StepCount);
@@ -1306,7 +1308,7 @@ namespace builds_impl {
ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::CompareState, TaskSteps::StepCount);
- ChunkedFolderContent RemoteContent = GetRemoteContent(Output,
+ ChunkedFolderContent RemoteContent = GetRemoteContent(InLog,
Storage,
BuildId,
AllBuildParts,
@@ -1327,75 +1329,67 @@ namespace builds_impl {
BuildSaveState LocalState;
- if (!Options.PrimeCacheOnly)
+ if (IsDir(Path))
{
- if (IsDir(Path))
+ if (!ChunkController && !IsQuiet)
{
- if (!ChunkController && !IsQuiet)
- {
- ZEN_CONSOLE_INFO("Unspecified chunking algorithm, using default");
- ChunkController = CreateStandardChunkingController(StandardChunkingControllerSettings{});
- }
- std::unique_ptr<ChunkingCache> ChunkCache(CreateNullChunkingCache());
-
- LocalState = GetLocalContent(Workers,
- LocalFolderScanStats,
- ChunkingStats,
- Path,
- ZenStateFilePath(Path / ZenFolderName),
- *ChunkController,
- *ChunkCache);
-
- std::vector<std::filesystem::path> UntrackedPaths = GetNewPaths(LocalState.State.ChunkedContent.Paths, RemoteContent.Paths);
-
- BuildSaveState UntrackedLocalContent = GetLocalStateFromPaths(Workers,
- LocalFolderScanStats,
- ChunkingStats,
- Path,
- *ChunkController,
- *ChunkCache,
- UntrackedPaths);
-
- if (!UntrackedLocalContent.State.ChunkedContent.Paths.empty())
- {
- LocalState.State.ChunkedContent =
- MergeChunkedFolderContents(LocalState.State.ChunkedContent,
- std::vector<ChunkedFolderContent>{UntrackedLocalContent.State.ChunkedContent});
-
- // TODO: Helper
- LocalState.FolderState.Paths.insert(LocalState.FolderState.Paths.begin(),
- UntrackedLocalContent.FolderState.Paths.begin(),
- UntrackedLocalContent.FolderState.Paths.end());
- LocalState.FolderState.RawSizes.insert(LocalState.FolderState.RawSizes.begin(),
- UntrackedLocalContent.FolderState.RawSizes.begin(),
- UntrackedLocalContent.FolderState.RawSizes.end());
- LocalState.FolderState.Attributes.insert(LocalState.FolderState.Attributes.begin(),
- UntrackedLocalContent.FolderState.Attributes.begin(),
- UntrackedLocalContent.FolderState.Attributes.end());
- LocalState.FolderState.ModificationTicks.insert(LocalState.FolderState.ModificationTicks.begin(),
- UntrackedLocalContent.FolderState.ModificationTicks.begin(),
- UntrackedLocalContent.FolderState.ModificationTicks.end());
- }
+ ZEN_CONSOLE_INFO("Unspecified chunking algorithm, using default");
+ ChunkController = CreateStandardChunkingController(StandardChunkingControllerSettings{});
+ }
+ std::unique_ptr<ChunkingCache> ChunkCache(CreateNullChunkingCache());
- if (Options.AppendNewContent)
- {
- RemoteContent = ApplyChunkedContentOverlay(LocalState.State.ChunkedContent,
- RemoteContent,
- Options.IncludeWildcards,
- Options.ExcludeWildcards);
- }
-#if ZEN_BUILD_DEBUG
- ValidateChunkedFolderContent(RemoteContent,
- BlockDescriptions,
- LooseChunkHashes,
- Options.IncludeWildcards,
- Options.ExcludeWildcards);
-#endif // ZEN_BUILD_DEBUG
+ LocalState = GetLocalContent(Workers,
+ LocalFolderScanStats,
+ ChunkingStats,
+ Path,
+ ZenStateFilePath(Path / ZenFolderName),
+ *ChunkController,
+ *ChunkCache);
+
+ std::vector<std::filesystem::path> UntrackedPaths = GetNewPaths(LocalState.State.ChunkedContent.Paths, RemoteContent.Paths);
+
+ BuildSaveState UntrackedLocalContent =
+ GetLocalStateFromPaths(Workers, LocalFolderScanStats, ChunkingStats, Path, *ChunkController, *ChunkCache, UntrackedPaths);
+
+ if (!UntrackedLocalContent.State.ChunkedContent.Paths.empty())
+ {
+ LocalState.State.ChunkedContent =
+ MergeChunkedFolderContents(LocalState.State.ChunkedContent,
+ std::vector<ChunkedFolderContent>{UntrackedLocalContent.State.ChunkedContent});
+
+ // TODO: Helper
+ LocalState.FolderState.Paths.insert(LocalState.FolderState.Paths.begin(),
+ UntrackedLocalContent.FolderState.Paths.begin(),
+ UntrackedLocalContent.FolderState.Paths.end());
+ LocalState.FolderState.RawSizes.insert(LocalState.FolderState.RawSizes.begin(),
+ UntrackedLocalContent.FolderState.RawSizes.begin(),
+ UntrackedLocalContent.FolderState.RawSizes.end());
+ LocalState.FolderState.Attributes.insert(LocalState.FolderState.Attributes.begin(),
+ UntrackedLocalContent.FolderState.Attributes.begin(),
+ UntrackedLocalContent.FolderState.Attributes.end());
+ LocalState.FolderState.ModificationTicks.insert(LocalState.FolderState.ModificationTicks.begin(),
+ UntrackedLocalContent.FolderState.ModificationTicks.begin(),
+ UntrackedLocalContent.FolderState.ModificationTicks.end());
}
- else
+
+ if (Options.AppendNewContent)
{
- CreateDirectories(Path);
+ RemoteContent = ApplyChunkedContentOverlay(LocalState.State.ChunkedContent,
+ RemoteContent,
+ Options.IncludeWildcards,
+ Options.ExcludeWildcards);
}
+#if ZEN_BUILD_DEBUG
+ ValidateChunkedFolderContent(RemoteContent,
+ BlockDescriptions,
+ LooseChunkHashes,
+ Options.IncludeWildcards,
+ Options.ExcludeWildcards);
+#endif // ZEN_BUILD_DEBUG
+ }
+ else
+ {
+ CreateDirectories(Path);
}
if (AbortFlag)
{
@@ -1473,13 +1467,14 @@ namespace builds_impl {
if (!IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(Output, "Indexed local and remote content in {}", NiceTimeSpanMs(IndexTimer.GetElapsedTimeMs()));
+ ZEN_INFO("Indexed local and remote content in {}", NiceTimeSpanMs(IndexTimer.GetElapsedTimeMs()));
}
ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Download, TaskSteps::StepCount);
BuildsOperationUpdateFolder Updater(
- Output,
+ InLog,
+ Progress,
Storage,
AbortFlag,
PauseFlag,
@@ -1504,7 +1499,6 @@ namespace builds_impl {
.PreferredMultipartChunkSize = PreferredMultipartChunkSize,
.PartialBlockRequestMode = Options.PartialBlockRequestMode,
.WipeTargetFolder = Options.CleanTargetFolder,
- .PrimeCacheOnly = Options.PrimeCacheOnly,
.EnableOtherDownloadsScavenging = Options.EnableOtherDownloadsScavenging,
.EnableTargetFolderScavenging = Options.EnableTargetFolderScavenging || Options.AppendNewContent,
.ValidateCompletedSequences = Options.PostDownloadVerify,
@@ -1524,40 +1518,37 @@ namespace builds_impl {
VerifyFolderStatistics VerifyFolderStats;
if (!AbortFlag)
{
- if (!Options.PrimeCacheOnly)
+ AddDownloadedPath(Options.SystemRootDir,
+ BuildsDownloadInfo{.Selection = LocalState.State.Selection,
+ .LocalPath = Path,
+ .StateFilePath = ZenStateFilePath(Options.ZenFolderPath),
+ .Iso8601Date = DateTime::Now().ToIso8601()});
+
+ ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Verify, TaskSteps::StepCount);
+
+ VerifyFolder(Workers,
+ RemoteContent,
+ RemoteLookup,
+ Path,
+ Options.ExcludeFolders,
+ Options.PostDownloadVerify,
+ VerifyFolderStats);
+
+ Stopwatch WriteStateTimer;
+ CbObject StateObject = CreateBuildSaveStateObject(LocalState);
+
+ CreateDirectories(ZenStateFilePath(Options.ZenFolderPath).parent_path());
+ TemporaryFile::SafeWriteFile(ZenStateFilePath(Options.ZenFolderPath), StateObject.GetView());
+ if (!IsQuiet)
{
- AddDownloadedPath(Options.SystemRootDir,
- BuildsDownloadInfo{.Selection = LocalState.State.Selection,
- .LocalPath = Path,
- .StateFilePath = ZenStateFilePath(Options.ZenFolderPath),
- .Iso8601Date = DateTime::Now().ToIso8601()});
-
- ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Verify, TaskSteps::StepCount);
-
- VerifyFolder(Workers,
- RemoteContent,
- RemoteLookup,
- Path,
- Options.ExcludeFolders,
- Options.PostDownloadVerify,
- VerifyFolderStats);
-
- Stopwatch WriteStateTimer;
- CbObject StateObject = CreateBuildSaveStateObject(LocalState);
-
- CreateDirectories(ZenStateFilePath(Options.ZenFolderPath).parent_path());
- TemporaryFile::SafeWriteFile(ZenStateFilePath(Options.ZenFolderPath), StateObject.GetView());
- if (!IsQuiet)
- {
- ZEN_CONSOLE("Wrote local state in {}", NiceTimeSpanMs(WriteStateTimer.GetElapsedTimeMs()));
- }
+ ZEN_CONSOLE("Wrote local state in {}", NiceTimeSpanMs(WriteStateTimer.GetElapsedTimeMs()));
+ }
#if 0
ExtendableStringBuilder<1024> SB;
CompactBinaryToJson(StateObject, SB);
WriteFile(ZenStateFileJsonPath(Options.ZenFolderPath), IoBuffer(IoBuffer::Wrap, SB.Data(), SB.Size()));
#endif // 0
- }
const uint64_t DownloadCount = Updater.m_DownloadStats.DownloadedChunkCount.load() +
Updater.m_DownloadStats.DownloadedBlockCount.load() +
Updater.m_DownloadStats.DownloadedPartialBlockCount.load();
@@ -1647,26 +1638,6 @@ namespace builds_impl {
}
}
}
- if (Options.PrimeCacheOnly)
- {
- if (Storage.CacheStorage)
- {
- Storage.CacheStorage->Flush(5000, [](intptr_t Remaining) {
- if (!IsQuiet)
- {
- if (Remaining == 0)
- {
- ZEN_CONSOLE("Build cache upload complete");
- }
- else
- {
- ZEN_CONSOLE("Waiting for build cache to complete uploading. {} blobs remaining", Remaining);
- }
- }
- return !AbortFlag;
- });
- }
- }
ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Cleanup, TaskSteps::StepCount);
@@ -1784,6 +1755,7 @@ namespace builds_impl {
{
OptionalStructuredOutput->AddString("path"sv, fmt::format("{}", Path));
OptionalStructuredOutput->AddInteger("rawSize"sv, RawSize);
+ OptionalStructuredOutput->AddHash("rawHash"sv, RawHash);
switch (Platform)
{
case SourcePlatform::Windows:
@@ -2011,7 +1983,7 @@ namespace builds_impl {
} // namespace builds_impl
//////////////////////////////////////////////////////////////////////////////////////////////////////
-// BuildsCommand — Option-adding helpers
+// BuildsCommand - Option-adding helpers
//
void
@@ -2300,10 +2272,6 @@ BuildsCommand::OnParentOptionsParsed(const ZenCliOptions& /*GlobalOptions*/)
{
ProgressMode = ProgressBar::Mode::Plain;
}
- else if (m_Verbose)
- {
- ProgressMode = ProgressBar::Mode::Plain;
- }
else if (IsQuiet)
{
ProgressMode = ProgressBar::Mode::Quiet;
@@ -2416,13 +2384,8 @@ BuildsCommand::CreateBuildStorage(BuildStorageBase::Statistics& StorageStats,
/*Hidden*/ false,
m_Verbose);
- BuildStorageResolveResult ResolveRes = ResolveBuildStorage(*CreateConsoleLogOutput(ProgressMode),
- ClientSettings,
- m_Host,
- m_OverrideHost,
- m_ZenCacheHost,
- ZenCacheResolveMode::All,
- m_Verbose);
+ BuildStorageResolveResult ResolveRes =
+ ResolveBuildStorage(ConsoleLog(), ClientSettings, m_Host, m_OverrideHost, m_ZenCacheHost, ZenCacheResolveMode::All, m_Verbose);
if (!ResolveRes.Cloud.Address.empty())
{
ClientSettings.AssumeHttp2 = ResolveRes.Cloud.AssumeHttp2;
@@ -2793,12 +2756,8 @@ BuildsCommand::ResolveZenFolderPath(const std::filesystem::path& DefaultPath)
}
EPartialBlockRequestMode
-BuildsCommand::ParseAllowPartialBlockRequests(bool PrimeCacheOnly, cxxopts::Options& SubOpts)
+BuildsCommand::ParseAllowPartialBlockRequests(cxxopts::Options& SubOpts)
{
- if (PrimeCacheOnly)
- {
- return EPartialBlockRequestMode::Off;
- }
EPartialBlockRequestMode Mode = PartialBlockRequestModeFromString(m_AllowPartialBlockRequests);
if (Mode == EPartialBlockRequestMode::Invalid)
{
@@ -3284,10 +3243,11 @@ BuildsUploadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
? CreateNullChunkingCache()
: CreateDiskChunkingCache(m_Parent.m_ChunkingCachePath, *ChunkController, 256u * 1024u);
- std::unique_ptr<OperationLogOutput> Output(CreateConsoleLogOutput(ProgressMode));
+ std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode));
std::vector<std::pair<Oid, std::string>> UploadedParts =
- UploadFolder(*Output,
+ UploadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
BuildId,
@@ -3314,7 +3274,7 @@ BuildsUploadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
{
for (const auto& Part : UploadedParts)
{
- ValidateBuildPart(*Output, Workers, *Storage.BuildStorage, BuildId, Part.first, Part.second);
+ ValidateBuildPart(ConsoleLog(), *Progress, Workers, *Storage.BuildStorage, BuildId, Part.first, Part.second);
}
}
}
@@ -3364,13 +3324,6 @@ BuildsDownloadSubCmd::BuildsDownloadSubCmd(BuildsCommand& Parent)
Parent.AddAppendNewContentOptions(Opts);
Parent.AddExcludeFolderOption(Opts);
- Opts.add_option("cache",
- "",
- "cache-prime-only",
- "Only download blobs missing in cache and upload to cache",
- cxxopts::value(m_PrimeCacheOnly),
- "<cacheprimeonly>");
-
Opts.add_option("", "l", "local-path", "Root file system folder for build", cxxopts::value(m_Path), "<local-path>");
Opts.add_option("", "", "build-id", "Build Id", cxxopts::value(m_BuildId), "<id>");
Opts.add_option("",
@@ -3470,36 +3423,16 @@ BuildsDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
m_BuildId,
/*RequireNamespace*/ true,
/*RequireBucket*/ true,
- /*BoostCacheBackgroundWorkerPool*/ m_PrimeCacheOnly,
+ /*BoostCacheBackgroundWorkerPool*/ false,
Auth,
Opts);
const Oid BuildId = m_Parent.ParseBuildId(m_BuildId, Opts);
- if (m_PostDownloadVerify && m_PrimeCacheOnly)
- {
- throw OptionParseException("'--cache-prime-only' conflicts with '--verify'", Opts.help());
- }
-
- if (m_Clean && m_PrimeCacheOnly)
- {
- ZEN_CONSOLE_WARN("Ignoring '--clean' option when '--cache-prime-only' is enabled");
- }
-
- if (m_Force && m_PrimeCacheOnly)
- {
- ZEN_CONSOLE_WARN("Ignoring '--force' option when '--cache-prime-only' is enabled");
- }
-
- if (m_Parent.m_AllowPartialBlockRequests != "false" && m_PrimeCacheOnly)
- {
- ZEN_CONSOLE_WARN("Ignoring '--allow-partial-block-requests' option when '--cache-prime-only' is enabled");
- }
-
std::vector<Oid> BuildPartIds = m_Parent.ParseBuildPartIds(m_BuildPartIds, Opts);
std::vector<std::string> BuildPartNames = m_Parent.ParseBuildPartNames(m_BuildPartNames, Opts);
- EPartialBlockRequestMode PartialBlockRequestMode = m_Parent.ParseAllowPartialBlockRequests(m_PrimeCacheOnly, Opts);
+ EPartialBlockRequestMode PartialBlockRequestMode = m_Parent.ParseAllowPartialBlockRequests(Opts);
if (m_Parent.m_AppendNewContent && m_Clean)
{
@@ -3510,10 +3443,11 @@ BuildsDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
std::vector<std::string> ExcludeExtensions = DefaultExcludeExtensions;
m_Parent.ParseExcludeFolderAndExtension(ExcludeFolders, ExcludeExtensions);
- std::unique_ptr<OperationLogOutput> Output(CreateConsoleLogOutput(ProgressMode));
+ std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode));
DownloadFolder(
- *Output,
+ ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -3528,7 +3462,6 @@ BuildsDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = m_Clean,
.PostDownloadVerify = m_PostDownloadVerify,
- .PrimeCacheOnly = m_PrimeCacheOnly,
.EnableOtherDownloadsScavenging = m_EnableScavenging && !m_Force,
.EnableTargetFolderScavenging = !m_Force,
.AllowFileClone = m_AllowFileClone,
@@ -3898,9 +3831,10 @@ BuildsPrimeCacheSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
ProgressBar::SetLogOperationName(ProgressMode, "Prime Cache");
- std::unique_ptr<OperationLogOutput> Output(CreateConsoleLogOutput(ProgressMode));
+ std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode));
- BuildsOperationPrimeCache PrimeOp(*Output,
+ BuildsOperationPrimeCache PrimeOp(ConsoleLog(),
+ *Progress,
Storage,
AbortFlag,
PauseFlag,
@@ -4056,9 +3990,9 @@ BuildsValidatePartSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
const Oid BuildPartId = m_BuildPartName.empty() ? Oid::Zero : m_Parent.ParseBuildPartId(m_BuildPartId, Opts);
- std::unique_ptr<OperationLogOutput> Output(CreateConsoleLogOutput(ProgressMode));
+ std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode));
- ValidateBuildPart(*Output, Workers, *Storage.BuildStorage, BuildId, BuildPartId, m_BuildPartName);
+ ValidateBuildPart(ConsoleLog(), *Progress, Workers, *Storage.BuildStorage, BuildId, BuildPartId, m_BuildPartName);
if (AbortFlag)
{
@@ -4131,7 +4065,7 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
}
});
- EPartialBlockRequestMode PartialBlockRequestMode = m_Parent.ParseAllowPartialBlockRequests(false, Opts);
+ EPartialBlockRequestMode PartialBlockRequestMode = m_Parent.ParseAllowPartialBlockRequests(Opts);
BuildStorageBase::Statistics StorageStats;
BuildStorageCache::Statistics StorageCacheStats;
@@ -4202,9 +4136,10 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
? CreateNullChunkingCache()
: CreateDiskChunkingCache(m_Parent.m_ChunkingCachePath, *ChunkController, 256u * 1024u);
- std::unique_ptr<OperationLogOutput> Output(CreateConsoleLogOutput(ProgressMode));
+ std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode));
- UploadFolder(*Output,
+ UploadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
BuildId,
@@ -4231,7 +4166,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
{
ZEN_CONSOLE("Upload Build {}, Part {} ({}) from '{}' with chunking cache", m_BuildId, BuildPartId, m_BuildPartName, m_Path);
- UploadFolder(*Output,
+ UploadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
Oid::NewOid(),
@@ -4256,7 +4192,7 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
}
}
- ValidateBuildPart(*Output, Workers, *Storage.BuildStorage, BuildId, BuildPartId, m_BuildPartName);
+ ValidateBuildPart(ConsoleLog(), *Progress, Workers, *Storage.BuildStorage, BuildId, BuildPartId, m_BuildPartName);
if (!m_Parent.m_IncludeWildcard.empty() || !m_Parent.m_ExcludeWildcard.empty())
{
@@ -4268,7 +4204,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
std::vector<std::string> ExcludeWildcards;
m_Parent.ParseFileFilters(IncludeWildcards, ExcludeWildcards);
- DownloadFolder(*Output,
+ DownloadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -4283,7 +4220,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = true,
.PostDownloadVerify = true,
- .PrimeCacheOnly = false,
.EnableOtherDownloadsScavenging = m_EnableScavenging,
.EnableTargetFolderScavenging = false,
.AllowFileClone = m_AllowFileClone,
@@ -4300,7 +4236,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
BuildPartId,
m_BuildPartName,
DownloadPath);
- DownloadFolder(*Output,
+ DownloadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -4315,7 +4252,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = true,
.PostDownloadVerify = true,
- .PrimeCacheOnly = false,
.EnableOtherDownloadsScavenging = m_EnableScavenging,
.EnableTargetFolderScavenging = true,
.AllowFileClone = m_AllowFileClone,
@@ -4328,7 +4264,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
}
ZEN_CONSOLE("\nDownload Full Build {}, Part {} ({}) to '{}'", BuildId, BuildPartId, m_BuildPartName, DownloadPath);
- DownloadFolder(*Output,
+ DownloadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -4343,7 +4280,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = false,
.PostDownloadVerify = true,
- .PrimeCacheOnly = false,
.EnableOtherDownloadsScavenging = m_EnableScavenging,
.EnableTargetFolderScavenging = true,
.AllowFileClone = m_AllowFileClone,
@@ -4357,7 +4293,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
}
ZEN_CONSOLE("\nDownload Build {}, Part {} ({}) to '{}'", BuildId, BuildPartId, m_BuildPartName, DownloadPath);
- DownloadFolder(*Output,
+ DownloadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -4373,7 +4310,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = true,
.PostDownloadVerify = true,
- .PrimeCacheOnly = false,
.EnableOtherDownloadsScavenging = m_EnableScavenging,
.EnableTargetFolderScavenging = false,
.AllowFileClone = m_AllowFileClone});
@@ -4383,7 +4319,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
}
ZEN_CONSOLE("\nRe-download Build {}, Part {} ({}) to '{}' (identical target)", BuildId, BuildPartId, m_BuildPartName, DownloadPath);
- DownloadFolder(*Output,
+ DownloadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -4399,7 +4336,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = false,
.PostDownloadVerify = true,
- .PrimeCacheOnly = false,
.EnableOtherDownloadsScavenging = m_EnableScavenging,
.EnableTargetFolderScavenging = true,
.AllowFileClone = m_AllowFileClone});
@@ -4501,7 +4437,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
ScrambleDir(DownloadPath);
ZEN_CONSOLE("\nRe-download Build {}, Part {} ({}) to '{}' (scrambled target)", BuildId, BuildPartId, m_BuildPartName, DownloadPath);
- DownloadFolder(*Output,
+ DownloadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -4517,7 +4454,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = false,
.PostDownloadVerify = true,
- .PrimeCacheOnly = false,
.EnableOtherDownloadsScavenging = m_EnableScavenging,
.EnableTargetFolderScavenging = true,
.AllowFileClone = m_AllowFileClone});
@@ -4538,7 +4474,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
ZEN_CONSOLE("\nUpload scrambled Build {}, Part {} ({})\n{}\n", BuildId2, BuildPartId2, m_BuildPartName, SB.ToView());
}
- UploadFolder(*Output,
+ UploadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
BuildId2,
@@ -4562,10 +4499,11 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
throw std::runtime_error("Test aborted. (Upload scrambled)");
}
- ValidateBuildPart(*Output, Workers, *Storage.BuildStorage, BuildId, BuildPartId, m_BuildPartName);
+ ValidateBuildPart(ConsoleLog(), *Progress, Workers, *Storage.BuildStorage, BuildId, BuildPartId, m_BuildPartName);
ZEN_CONSOLE("\nDownload Build {}, Part {} ({}) to '{}' (original)", BuildId, BuildPartId, m_BuildPartName, DownloadPath);
- DownloadFolder(*Output,
+ DownloadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -4581,7 +4519,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = false,
.PostDownloadVerify = true,
- .PrimeCacheOnly = false,
.EnableOtherDownloadsScavenging = m_EnableScavenging,
.EnableTargetFolderScavenging = true,
.AllowFileClone = m_AllowFileClone});
@@ -4591,7 +4528,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
}
ZEN_CONSOLE("\nDownload Build {}, Part {} ({}) to '{}' (scrambled)", BuildId2, BuildPartId2, m_BuildPartName, DownloadPath);
- DownloadFolder(*Output,
+ DownloadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -4606,7 +4544,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = false,
.PostDownloadVerify = true,
- .PrimeCacheOnly = false,
.EnableOtherDownloadsScavenging = m_EnableScavenging,
.EnableTargetFolderScavenging = true,
.AllowFileClone = m_AllowFileClone});
@@ -4616,7 +4553,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
}
ZEN_CONSOLE("\nRe-download Build {}, Part {} ({}) to '{}' (scrambled)", BuildId2, BuildPartId2, m_BuildPartName, DownloadPath);
- DownloadFolder(*Output,
+ DownloadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -4631,7 +4569,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = false,
.PostDownloadVerify = true,
- .PrimeCacheOnly = false,
.EnableOtherDownloadsScavenging = m_EnableScavenging,
.EnableTargetFolderScavenging = true,
.AllowFileClone = m_AllowFileClone});
@@ -4641,7 +4578,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
}
ZEN_CONSOLE("\nDownload Build {}, Part {} ({}) to '{}' (original)", BuildId, BuildPartId, m_BuildPartName, DownloadPath2);
- DownloadFolder(*Output,
+ DownloadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -4656,7 +4594,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = false,
.PostDownloadVerify = true,
- .PrimeCacheOnly = false,
.EnableOtherDownloadsScavenging = m_EnableScavenging,
.EnableTargetFolderScavenging = true,
.AllowFileClone = m_AllowFileClone});
@@ -4666,7 +4603,8 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
}
ZEN_CONSOLE("\nDownload Build {}, Part {} ({}) to '{}' (original)", BuildId, BuildPartId, m_BuildPartName, DownloadPath3);
- DownloadFolder(*Output,
+ DownloadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -4681,7 +4619,6 @@ BuildsTestSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = false,
.PostDownloadVerify = true,
- .PrimeCacheOnly = false,
.EnableOtherDownloadsScavenging = m_EnableScavenging,
.EnableTargetFolderScavenging = true,
.AllowFileClone = m_AllowFileClone});
@@ -4741,7 +4678,7 @@ BuildsMultiTestDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
m_Parent.ResolveZenFolderPath(m_Path / ZenFolderName);
- EPartialBlockRequestMode PartialBlockRequestMode = m_Parent.ParseAllowPartialBlockRequests(false, Opts);
+ EPartialBlockRequestMode PartialBlockRequestMode = m_Parent.ParseAllowPartialBlockRequests(Opts);
BuildStorageBase::Statistics StorageStats;
BuildStorageCache::Statistics StorageCacheStats;
@@ -4758,7 +4695,7 @@ BuildsMultiTestDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
Auth,
Opts);
- std::unique_ptr<OperationLogOutput> Output(CreateConsoleLogOutput(ProgressMode));
+ std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode));
Stopwatch Timer;
for (const std::string& BuildIdString : m_BuildIds)
@@ -4768,7 +4705,8 @@ BuildsMultiTestDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
{
throw OptionParseException(fmt::format("'--build-id' ('{}') is malformed", BuildIdString), Opts.help());
}
- DownloadFolder(*Output,
+ DownloadFolder(ConsoleLog(),
+ *Progress,
Workers,
Storage,
StorageCacheStats,
@@ -4783,7 +4721,6 @@ BuildsMultiTestDownloadSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
.PartialBlockRequestMode = PartialBlockRequestMode,
.CleanTargetFolder = BuildIdString == m_BuildIds.front(),
.PostDownloadVerify = true,
- .PrimeCacheOnly = false,
.EnableOtherDownloadsScavenging = m_EnableScavenging,
.EnableTargetFolderScavenging = false,
.AllowFileClone = m_AllowFileClone});
diff --git a/src/zen/cmds/builds_cmd.h b/src/zen/cmds/builds_cmd.h
index 7ef71e176..ef7500fd6 100644
--- a/src/zen/cmds/builds_cmd.h
+++ b/src/zen/cmds/builds_cmd.h
@@ -94,7 +94,6 @@ private:
bool m_EnableScavenging = true;
std::filesystem::path m_DownloadSpecPath;
bool m_UploadToZenCache = true;
- bool m_PrimeCacheOnly = false;
bool m_AllowFileClone = true;
};
@@ -280,7 +279,7 @@ public:
cxxopts::Options& SubOpts);
void ParsePath(std::filesystem::path& Path, cxxopts::Options& SubOpts);
IoHash ParseBlobHash(const std::string& BlobHashStr, cxxopts::Options& SubOpts);
- EPartialBlockRequestMode ParseAllowPartialBlockRequests(bool PrimeCacheOnly, cxxopts::Options& SubOpts);
+ EPartialBlockRequestMode ParseAllowPartialBlockRequests(cxxopts::Options& SubOpts);
void ParseZenProcessId(int& ZenProcessId);
void ParseFileFilters(std::vector<std::string>& OutIncludeWildcards, std::vector<std::string>& OutExcludeWildcards);
void ParseExcludeFolderAndExtension(std::vector<std::string>& OutExcludeFolders, std::vector<std::string>& OutExcludeExtensions);
diff --git a/src/zen/cmds/compute_cmd.cpp b/src/zen/cmds/compute_cmd.cpp
new file mode 100644
index 000000000..01166cb0e
--- /dev/null
+++ b/src/zen/cmds/compute_cmd.cpp
@@ -0,0 +1,96 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "compute_cmd.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/logging.h>
+# include <zenhttp/httpclient.h>
+
+using namespace std::literals;
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+// ComputeRecordStartSubCmd
+
+ComputeRecordStartSubCmd::ComputeRecordStartSubCmd() : ZenSubCmdBase("record-start", "Start recording compute actions")
+{
+ SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>");
+}
+
+void
+ComputeRecordStartSubCmd::Run(const ZenCliOptions& GlobalOptions)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName);
+ if (m_HostName.empty())
+ {
+ throw OptionParseException("Unable to resolve server specification", SubOptions().help());
+ }
+
+ HttpClient Http = ZenCmdBase::CreateHttpClient(m_HostName);
+ if (HttpClient::Response Response = Http.Post("/compute/record/start"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{}))
+ {
+ CbObject Obj = Response.AsObject();
+ std::string_view Path = Obj["path"sv].AsString();
+ ZEN_CONSOLE("recording started: " ZEN_BRIGHT_GREEN("{}"), Path);
+ }
+ else
+ {
+ Response.ThrowError("Failed to start recording");
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// ComputeRecordStopSubCmd
+
+ComputeRecordStopSubCmd::ComputeRecordStopSubCmd() : ZenSubCmdBase("record-stop", "Stop recording compute actions")
+{
+ SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>");
+}
+
+void
+ComputeRecordStopSubCmd::Run(const ZenCliOptions& GlobalOptions)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName);
+ if (m_HostName.empty())
+ {
+ throw OptionParseException("Unable to resolve server specification", SubOptions().help());
+ }
+
+ HttpClient Http = ZenCmdBase::CreateHttpClient(m_HostName);
+ if (HttpClient::Response Response = Http.Post("/compute/record/stop"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{}))
+ {
+ CbObject Obj = Response.AsObject();
+ std::string_view Path = Obj["path"sv].AsString();
+ ZEN_CONSOLE("recording stopped: " ZEN_BRIGHT_GREEN("{}"), Path);
+ }
+ else
+ {
+ Response.ThrowError("Failed to stop recording");
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// ComputeCommand
+
+ComputeCommand::ComputeCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("__hidden__", "", "subcommand", "", cxxopts::value<std::string>(m_SubCommand)->default_value(""), "");
+ m_Options.parse_positional({"subcommand"});
+
+ AddSubCommand(m_RecordStartSubCmd);
+ AddSubCommand(m_RecordStopSubCmd);
+}
+
+ComputeCommand::~ComputeCommand() = default;
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zen/cmds/compute_cmd.h b/src/zen/cmds/compute_cmd.h
new file mode 100644
index 000000000..b26f639c4
--- /dev/null
+++ b/src/zen/cmds/compute_cmd.h
@@ -0,0 +1,53 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+#include <string>
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+namespace zen {
+
+class ComputeRecordStartSubCmd : public ZenSubCmdBase
+{
+public:
+ ComputeRecordStartSubCmd();
+ void Run(const ZenCliOptions& GlobalOptions) override;
+
+private:
+ std::string m_HostName;
+};
+
+class ComputeRecordStopSubCmd : public ZenSubCmdBase
+{
+public:
+ ComputeRecordStopSubCmd();
+ void Run(const ZenCliOptions& GlobalOptions) override;
+
+private:
+ std::string m_HostName;
+};
+
+class ComputeCommand : public ZenCmdWithSubCommands
+{
+public:
+ static constexpr char Name[] = "compute";
+ static constexpr char Description[] = "Compute service operations";
+
+ ComputeCommand();
+ ~ComputeCommand();
+
+ cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{Name, Description};
+ std::string m_SubCommand;
+ ComputeRecordStartSubCmd m_RecordStartSubCmd;
+ ComputeRecordStopSubCmd m_RecordStopSubCmd;
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp
index 30e860a3f..dab53f13c 100644
--- a/src/zen/cmds/exec_cmd.cpp
+++ b/src/zen/cmds/exec_cmd.cpp
@@ -23,6 +23,8 @@
#include <zenhttp/httpclient.h>
#include <zenhttp/packageformat.h>
+#include "../progressbar.h"
+
#include <EASTL/hash_map.h>
#include <EASTL/hash_set.h>
#include <EASTL/map.h>
@@ -114,7 +116,7 @@ namespace {
} // namespace
//////////////////////////////////////////////////////////////////////////
-// ExecSessionConfig — read-only configuration for a session run
+// ExecSessionConfig - read-only configuration for a session run
struct ExecSessionConfig
{
@@ -124,17 +126,18 @@ struct ExecSessionConfig
std::vector<ExecFunctionDefinition>& FunctionList; // mutable for EmitFunctionListOnce
std::string_view OrchestratorUrl;
const std::filesystem::path& OutputPath;
- int Offset = 0;
- int Stride = 1;
- int Limit = 0;
- bool Verbose = false;
- bool Quiet = false;
- bool DumpActions = false;
- bool Binary = false;
+ int Offset = 0;
+ int Stride = 1;
+ int Limit = 0;
+ bool Verbose = false;
+ bool Quiet = false;
+ bool DumpActions = false;
+ bool Binary = false;
+ ProgressBar::Mode ProgressMode = ProgressBar::Mode::Pretty;
};
//////////////////////////////////////////////////////////////////////////
-// ExecSessionRunner — owns per-run state, drives the session lifecycle
+// ExecSessionRunner - owns per-run state, drives the session lifecycle
class ExecSessionRunner
{
@@ -345,8 +348,6 @@ ExecSessionRunner::DrainCompletedJobs()
}
m_PendingJobs.Remove(CompleteLsn);
-
- ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, m_PendingJobs.GetSize());
}
}
}
@@ -897,17 +898,20 @@ ExecSessionRunner::Run()
// Then submit work items
- int FailedWorkCounter = 0;
- size_t RemainingWorkItems = m_Config.RecordingReader.GetActionCount();
- int SubmittedWorkItems = 0;
+ std::atomic<int> FailedWorkCounter{0};
+ std::atomic<size_t> RemainingWorkItems{m_Config.RecordingReader.GetActionCount()};
+ std::atomic<int> SubmittedWorkItems{0};
+ size_t TotalWorkItems = RemainingWorkItems.load();
- ZEN_CONSOLE("submitting {} work items", RemainingWorkItems);
+ ProgressBar SubmitProgress(m_Config.ProgressMode, "Submit");
+ SubmitProgress.UpdateState({.Task = "Submitting work items", .TotalCount = TotalWorkItems, .RemainingCount = RemainingWorkItems.load()},
+ false);
int OffsetCounter = m_Config.Offset;
int StrideCounter = m_Config.Stride;
auto ShouldSchedule = [&]() -> bool {
- if (m_Config.Limit && SubmittedWorkItems >= m_Config.Limit)
+ if (m_Config.Limit && SubmittedWorkItems.load() >= m_Config.Limit)
{
// Limit reached, ignore
@@ -1005,17 +1009,14 @@ ExecSessionRunner::Run()
{
const int32_t LsnField = EnqueueResult.Lsn;
- --RemainingWorkItems;
- ++SubmittedWorkItems;
+ size_t Remaining = --RemainingWorkItems;
+ int Submitted = ++SubmittedWorkItems;
- if (!m_Config.Quiet)
- {
- ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. {} remaining",
- SubmittedWorkItems,
- LsnField,
- NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
- RemainingWorkItems);
- }
+ SubmitProgress.UpdateState({.Task = "Submitting work items",
+ .Details = fmt::format("#{} LSN {}", Submitted, LsnField),
+ .TotalCount = TotalWorkItems,
+ .RemainingCount = Remaining},
+ false);
if (!m_Config.OutputPath.empty())
{
@@ -1055,22 +1056,36 @@ ExecSessionRunner::Run()
},
TargetParallelism);
+ SubmitProgress.Finish();
+
// Wait until all pending work is complete
+ size_t TotalPendingJobs = m_PendingJobs.GetSize();
+
+ ProgressBar CompletionProgress(m_Config.ProgressMode, "Execute");
+
while (!m_PendingJobs.IsEmpty())
{
- // TODO: improve this logic
- zen::Sleep(500);
+ size_t PendingCount = m_PendingJobs.GetSize();
+ CompletionProgress.UpdateState({.Task = "Executing work items",
+ .Details = fmt::format("{} completed, {} remaining", TotalPendingJobs - PendingCount, PendingCount),
+ .TotalCount = TotalPendingJobs,
+ .RemainingCount = PendingCount},
+ false);
+
+ zen::Sleep(GetUpdateDelayMS(m_Config.ProgressMode));
DrainCompletedJobs();
SendOrchestratorHeartbeat();
}
+ CompletionProgress.Finish();
+
// Write summary files
WriteSummaryFiles();
- if (FailedWorkCounter)
+ if (FailedWorkCounter.load())
{
return 1;
}
@@ -1119,6 +1134,7 @@ ExecHttpSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
ExecInprocSubCmd::ExecInprocSubCmd(ExecCommand& Parent) : ZenSubCmdBase("inproc", "Handle execution in-process"), m_Parent(Parent)
{
+ m_SubOptions.add_option("managed", "", "managed", "Use managed local runner (if supported)", cxxopts::value(m_Managed), "<bool>");
}
void
@@ -1130,7 +1146,16 @@ ExecInprocSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
zen::compute::ComputeServiceSession ComputeSession(Resolver);
std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");
- ComputeSession.AddLocalRunner(Resolver, TempPath);
+ if (m_Managed)
+ {
+ ZEN_CONSOLE_INFO("using managed local runner");
+ ComputeSession.AddManagedLocalRunner(Resolver, TempPath);
+ }
+ else
+ {
+ ZEN_CONSOLE_INFO("using local runner");
+ ComputeSession.AddLocalRunner(Resolver, TempPath);
+ }
Stopwatch ExecTimer;
int ReturnValue = m_Parent.RunSession(ComputeSession);
@@ -1413,6 +1438,12 @@ ExecCommand::OnParentOptionsParsed(const ZenCliOptions& GlobalOptions)
int
ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std::string_view OrchestratorUrl)
{
+ ProgressBar::Mode ProgressMode = ProgressBar::Mode::Pretty;
+ if (m_QuietLogging)
+ {
+ ProgressMode = ProgressBar::Mode::Quiet;
+ }
+
ExecSessionConfig Config{
.Resolver = *m_ChunkResolver,
.RecordingReader = *m_RecordingReader,
@@ -1427,6 +1458,7 @@ ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std
.Quiet = m_QuietLogging,
.DumpActions = m_DumpActions,
.Binary = m_Binary,
+ .ProgressMode = ProgressMode,
};
ExecSessionRunner Runner(ComputeSession, Config);
diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h
index c55412780..a0bf201a1 100644
--- a/src/zen/cmds/exec_cmd.h
+++ b/src/zen/cmds/exec_cmd.h
@@ -61,6 +61,7 @@ public:
private:
ExecCommand& m_Parent;
+ bool m_Managed = false;
};
class ExecBeaconSubCmd : public ZenSubCmdBase
diff --git a/src/zen/cmds/projectstore_cmd.cpp b/src/zen/cmds/projectstore_cmd.cpp
index d31c34fd0..7f94bf2df 100644
--- a/src/zen/cmds/projectstore_cmd.cpp
+++ b/src/zen/cmds/projectstore_cmd.cpp
@@ -24,10 +24,10 @@
#include <zenremotestore/builds/buildstorageutil.h>
#include <zenremotestore/builds/jupiterbuildstorage.h>
#include <zenremotestore/jupiter/jupiterhost.h>
-#include <zenremotestore/operationlogoutput.h>
#include <zenremotestore/projectstore/projectstoreoperations.h>
#include <zenremotestore/projectstore/remoteprojectstore.h>
#include <zenremotestore/transferthreadworkers.h>
+#include <zenutil/progress.h>
#include <zenutil/workerpools.h>
#include "../progressbar.h"
@@ -2500,10 +2500,6 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a
{
ProgressMode = ProgressBar::Mode::Plain;
}
- else if (m_Verbose)
- {
- ProgressMode = ProgressBar::Mode::Plain;
- }
else if (m_Quiet)
{
ProgressMode = ProgressBar::Mode::Quiet;
@@ -2565,7 +2561,7 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a
m_BoostWorkerMemory = true;
}
- std::unique_ptr<OperationLogOutput> OperationLogOutput(CreateConsoleLogOutput(ProgressMode));
+ std::unique_ptr<ProgressBase> Progress(CreateConsoleProgress(ProgressMode));
TransferThreadWorkers Workers(m_BoostWorkerCount, /*SingleThreaded*/ false);
if (!m_Quiet)
@@ -2594,7 +2590,7 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a
/*Hidden*/ false,
m_Verbose);
- BuildStorageResolveResult ResolveRes = ResolveBuildStorage(*OperationLogOutput,
+ BuildStorageResolveResult ResolveRes = ResolveBuildStorage(ConsoleLog(),
ClientSettings,
m_Host,
m_OverrideHost,
@@ -2677,7 +2673,7 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a
}
ProjectStoreOperationOplogState State(
- *OperationLogOutput,
+ ConsoleLog(),
Storage,
BuildId,
{.IsQuiet = m_Quiet, .IsVerbose = m_Verbose, .ForceDownload = m_ForceDownload, .TempFolderPath = StorageTempPath});
@@ -2704,7 +2700,8 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a
}
std::atomic<bool> PauseFlag;
- ProjectStoreOperationDownloadAttachments Op(*OperationLogOutput,
+ ProjectStoreOperationDownloadAttachments Op(ConsoleLog(),
+ *Progress,
Storage,
AbortFlag,
PauseFlag,
diff --git a/src/zen/cmds/projectstore_cmd.h b/src/zen/cmds/projectstore_cmd.h
index 1ba98b39e..41db36139 100644
--- a/src/zen/cmds/projectstore_cmd.h
+++ b/src/zen/cmds/projectstore_cmd.h
@@ -217,7 +217,7 @@ class SnapshotOplogCommand : public ProjectStoreCommand
{
public:
static constexpr char Name[] = "oplog-snapshot";
- static constexpr char Description[] = "Snapshot project store oplog";
+ static constexpr char Description[] = "Copy oplog's loose files on disk into zenserver";
SnapshotOplogCommand();
~SnapshotOplogCommand();
diff --git a/src/zen/cmds/service_cmd.cpp b/src/zen/cmds/service_cmd.cpp
index 37baf5483..5e284cbdf 100644
--- a/src/zen/cmds/service_cmd.cpp
+++ b/src/zen/cmds/service_cmd.cpp
@@ -320,7 +320,7 @@ ServiceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
throw OptionParseException("'verb' option is required", m_Options.help());
}
- // Parse subcommand permissively — forward unrecognised options to the parent parser.
+ // Parse subcommand permissively - forward unrecognised options to the parent parser.
std::vector<std::string> SubUnmatched;
if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched))
{
diff --git a/src/zen/cmds/ui_cmd.cpp b/src/zen/cmds/ui_cmd.cpp
index 4846b4d18..28ab6c45c 100644
--- a/src/zen/cmds/ui_cmd.cpp
+++ b/src/zen/cmds/ui_cmd.cpp
@@ -162,7 +162,7 @@ UiCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
Labels.push_back(fmt::format("(all {} instances)", Servers.size()));
const int32_t Cols = static_cast<int32_t>(TuiConsoleColumns());
- constexpr int32_t kIndicator = 3; // " ▶ " or " " prefix
+ constexpr int32_t kIndicator = 3; // " > " or " " prefix
constexpr int32_t kSeparator = 2; // " " before cmdline
constexpr int32_t kEllipsis = 3; // "..."
diff --git a/src/zen/cmds/wipe_cmd.cpp b/src/zen/cmds/wipe_cmd.cpp
index 10f5ad8e1..c027f0d67 100644
--- a/src/zen/cmds/wipe_cmd.cpp
+++ b/src/zen/cmds/wipe_cmd.cpp
@@ -4,6 +4,7 @@
#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
+#include <zencore/iohash.h>
#include <zencore/logging.h>
#include <zencore/parallelwork.h>
#include <zencore/string.h>
@@ -548,7 +549,7 @@ WipeCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
Quiet = m_Quiet;
IsVerbose = m_Verbose;
- ProgressMode = (IsVerbose || m_PlainProgress) ? ProgressBar::Mode::Plain : ProgressBar::Mode::Pretty;
+ ProgressMode = m_PlainProgress ? ProgressBar::Mode::Plain : ProgressBar::Mode::Pretty;
BoostWorkerThreads = m_BoostWorkerThreads;
MakeSafeAbsolutePathInPlace(m_Directory);
diff --git a/src/zen/cmds/workspaces_cmd.cpp b/src/zen/cmds/workspaces_cmd.cpp
index 9e49b464e..1ae2d58fa 100644
--- a/src/zen/cmds/workspaces_cmd.cpp
+++ b/src/zen/cmds/workspaces_cmd.cpp
@@ -137,7 +137,7 @@ WorkspaceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
throw OptionParseException("'verb' option is required", m_Options.help());
}
- // Parse subcommand permissively — forward unrecognised options to the parent parser.
+ // Parse subcommand permissively - forward unrecognised options to the parent parser.
std::vector<std::string> SubUnmatched;
if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched))
{
@@ -403,7 +403,7 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char**
throw OptionParseException("'verb' option is required", m_Options.help());
}
- // Parse subcommand permissively — forward unrecognised options to the parent parser.
+ // Parse subcommand permissively - forward unrecognised options to the parent parser.
std::vector<std::string> SubUnmatched;
if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched))
{
diff --git a/src/zen/progressbar.cpp b/src/zen/progressbar.cpp
index 6581cd116..780b08707 100644
--- a/src/zen/progressbar.cpp
+++ b/src/zen/progressbar.cpp
@@ -5,10 +5,15 @@
#include "progressbar.h"
+#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/windows.h>
-#include <zenremotestore/operationlogoutput.h>
#include <zenutil/consoletui.h>
+#include <zenutil/progress.h>
+
+#if !ZEN_PLATFORM_WINDOWS
+# include <csignal>
+#endif
ZEN_THIRD_PARTY_INCLUDES_START
#include <gsl/gsl-lite.hpp>
@@ -18,6 +23,87 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
+// Global tracking for scroll region cleanup on abnormal termination (Ctrl+C etc.)
+// Only one ProgressBar can own a scroll region at a time.
+static std::atomic<ProgressBar*> g_ActiveScrollRegionOwner{nullptr};
+static std::atomic<uint32_t> g_ActiveScrollRegionRows{0};
+
+static void
+ResetScrollRegionRaw()
+{
+ // Signal-safe: emit raw escape sequences to restore terminal state.
+ // These are async-signal-safe on POSIX (write()) and safe in console
+ // ctrl handlers on Windows (WriteConsole is allowed).
+ uint32_t Rows = g_ActiveScrollRegionRows.load(std::memory_order_acquire);
+ if (Rows >= 3)
+ {
+ // Move to status line, erase it, reset scroll region, move cursor to end of content
+ TuiMoveCursor(Rows, 1);
+ TuiEraseLine();
+ TuiResetScrollRegion();
+ TuiMoveCursor(Rows - 1, 1);
+ }
+ else
+ {
+ TuiResetScrollRegion();
+ }
+ TuiShowCursor(true);
+ TuiFlush();
+}
+
+#if ZEN_PLATFORM_WINDOWS
+static BOOL WINAPI
+ScrollRegionCtrlHandler(DWORD CtrlType)
+{
+ if (CtrlType == CTRL_C_EVENT || CtrlType == CTRL_BREAK_EVENT)
+ {
+ ResetScrollRegionRaw();
+ }
+ // Return FALSE so the default handler (process termination) still runs
+ return FALSE;
+}
+#else
+static struct sigaction s_PrevSigIntAction;
+static struct sigaction s_PrevSigTermAction;
+
+static void
+ScrollRegionSignalHandler(int Signal)
+{
+ ResetScrollRegionRaw();
+
+ // Re-raise with the previous handler
+ struct sigaction* PrevAction = (Signal == SIGINT) ? &s_PrevSigIntAction : &s_PrevSigTermAction;
+ sigaction(Signal, PrevAction, nullptr);
+ raise(Signal);
+}
+#endif
+
+static void
+InstallScrollRegionCleanupHandler()
+{
+#if ZEN_PLATFORM_WINDOWS
+ SetConsoleCtrlHandler(ScrollRegionCtrlHandler, TRUE);
+#else
+ struct sigaction Action = {};
+ Action.sa_handler = ScrollRegionSignalHandler;
+ Action.sa_flags = SA_RESETHAND; // one-shot
+ sigemptyset(&Action.sa_mask);
+ sigaction(SIGINT, &Action, &s_PrevSigIntAction);
+ sigaction(SIGTERM, &Action, &s_PrevSigTermAction);
+#endif
+}
+
+static void
+RemoveScrollRegionCleanupHandler()
+{
+#if ZEN_PLATFORM_WINDOWS
+ SetConsoleCtrlHandler(ScrollRegionCtrlHandler, FALSE);
+#else
+ sigaction(SIGINT, &s_PrevSigIntAction, nullptr);
+ sigaction(SIGTERM, &s_PrevSigTermAction, nullptr);
+#endif
+}
+
#if ZEN_PLATFORM_WINDOWS
static HANDLE
GetConsoleHandle()
@@ -128,12 +214,18 @@ ProgressBar::ProgressBar(Mode InMode, std::string_view InSubTask)
{
PushLogOperation(InMode, m_SubTask);
}
+
+ if (m_Mode == Mode::Pretty)
+ {
+ SetupScrollRegion();
+ }
}
ProgressBar::~ProgressBar()
{
try
{
+ TeardownScrollRegion();
ForceLinebreak();
if (!m_SubTask.empty())
{
@@ -147,6 +239,87 @@ ProgressBar::~ProgressBar()
}
void
+ProgressBar::SetupScrollRegion()
+{
+ // Only one scroll region owner at a time; nested bars fall back to the inline \r path.
+ if (g_ActiveScrollRegionOwner.load(std::memory_order_acquire) != nullptr)
+ {
+ return;
+ }
+
+ uint32_t Rows = TuiConsoleRows(0);
+ if (Rows < 3)
+ {
+ return;
+ }
+
+ TuiEnableOutput();
+
+ // Ensure cursor is not on the last row before we install the region.
+ // Print a newline to push content up if needed, then set the region.
+ OutputToConsoleRaw("\n");
+ TuiSetScrollRegion(1, Rows - 1);
+
+ // Move cursor into the scroll region so normal output stays there
+ TuiMoveCursor(Rows - 1, 1);
+
+ m_ScrollRegionActive = true;
+ m_ScrollRegionRows = Rows;
+
+ g_ActiveScrollRegionRows.store(Rows, std::memory_order_release);
+ g_ActiveScrollRegionOwner.store(this, std::memory_order_release);
+ InstallScrollRegionCleanupHandler();
+}
+
+void
+ProgressBar::TeardownScrollRegion()
+{
+ if (!m_ScrollRegionActive)
+ {
+ return;
+ }
+ m_ScrollRegionActive = false;
+
+ RemoveScrollRegionCleanupHandler();
+ g_ActiveScrollRegionOwner.store(nullptr, std::memory_order_release);
+ g_ActiveScrollRegionRows.store(0, std::memory_order_release);
+
+ // Emit all teardown escape sequences as a single atomic write
+ ExtendableStringBuilder<128> Buf;
+ Buf << fmt::format("\x1b[{};1H", m_ScrollRegionRows) // move to status line
+ << "\x1b[2K" // erase it
+ << "\x1b[r" // reset scroll region
+ << fmt::format("\x1b[{};1H", m_ScrollRegionRows - 1); // move to end of content
+ OutputToConsoleRaw(Buf);
+ TuiFlush();
+}
+
+void
+ProgressBar::RenderStatusLine(std::string_view Line)
+{
+ // Handle terminal resizes by re-querying row count
+ uint32_t CurrentRows = TuiConsoleRows(0);
+ if (CurrentRows >= 3 && CurrentRows != m_ScrollRegionRows)
+ {
+ // Terminal was resized - reinstall scroll region
+ TuiSetScrollRegion(1, CurrentRows - 1);
+ m_ScrollRegionRows = CurrentRows;
+ }
+
+ // Build the entire escape sequence as a single string so the console write
+ // is atomic and log output from other threads cannot interleave.
+ ExtendableStringBuilder<512> Buf;
+ Buf << "\x1b"
+ "7" // ESC 7 - save cursor
+ << fmt::format("\x1b[{};1H", m_ScrollRegionRows) // move to bottom row
+ << "\x1b[2K" // erase entire line
+ << Line // progress bar content
+ << "\x1b"
+ "8"; // ESC 8 - restore cursor
+ OutputToConsoleRaw(Buf);
+}
+
+void
ProgressBar::UpdateState(const State& NewState, bool DoLinebreak)
{
ZEN_ASSERT(NewState.TotalCount >= NewState.RemainingCount);
@@ -226,7 +399,7 @@ ProgressBar::UpdateState(const State& NewState, bool DoLinebreak)
ExtendableStringBuilder<256> OutputBuilder;
- OutputBuilder << "\r" << Task << " " << PercentString;
+ OutputBuilder << Task << " " << PercentString;
if (OutputBuilder.Size() + 1 < ConsoleColumns)
{
size_t RemainingSpace = ConsoleColumns - (OutputBuilder.Size() + 1);
@@ -257,34 +430,43 @@ ProgressBar::UpdateState(const State& NewState, bool DoLinebreak)
}
}
- std::string_view Output = OutputBuilder.ToView();
- std::string::size_type EraseLength = m_LastOutputLength > Output.length() ? (m_LastOutputLength - Output.length()) : 0;
+ if (m_ScrollRegionActive)
+ {
+ // Render on the pinned bottom status line
+ RenderStatusLine(OutputBuilder.ToView());
+ }
+ else
+ {
+ // Fallback: inline \r-based overwrite (terminal too small for scroll region)
+ std::string_view Output = OutputBuilder.ToView();
+ std::string::size_type EraseLength =
+ m_LastOutputLength > (Output.length() + 1) ? (m_LastOutputLength - Output.length() - 1) : 0;
+ ExtendableStringBuilder<256> LineToPrint;
- ExtendableStringBuilder<256> LineToPrint;
+ if (Output.length() + 1 + EraseLength >= ConsoleColumns)
+ {
+ if (m_LastOutputLength > 0)
+ {
+ LineToPrint << "\n";
+ }
+ LineToPrint << Output;
+ DoLinebreak = true;
+ }
+ else
+ {
+ LineToPrint << "\r" << Output << std::string(EraseLength, ' ');
+ }
- if (Output.length() + EraseLength >= ConsoleColumns)
- {
- if (m_LastOutputLength > 0)
+ if (DoLinebreak)
{
LineToPrint << "\n";
}
- LineToPrint << Output.substr(1);
- DoLinebreak = true;
- }
- else
- {
- LineToPrint << Output << std::string(EraseLength, ' ');
- }
- if (DoLinebreak)
- {
- LineToPrint << "\n";
+ OutputToConsoleRaw(LineToPrint);
+ m_LastOutputLength = DoLinebreak ? 0 : (Output.length() + 1); // +1 for \r prefix
}
- OutputToConsoleRaw(LineToPrint);
-
- m_LastOutputLength = DoLinebreak ? 0 : Output.length();
- m_State = NewState;
+ m_State = NewState;
}
else if (m_Mode == Mode::Log)
{
@@ -329,6 +511,8 @@ ProgressBar::ForceLinebreak()
void
ProgressBar::Finish()
{
+ TeardownScrollRegion();
+
if (m_LastOutputLength > 0 || m_State.RemainingCount > 0)
{
State NewState = m_State;
@@ -353,64 +537,29 @@ ProgressBar::HasActiveTask() const
return !m_State.Task.empty();
}
-class ConsoleOpLogProgressBar : public OperationLogOutput::ProgressBar
-{
-public:
- ConsoleOpLogProgressBar(zen::ProgressBar::Mode InMode, std::string_view InSubTask) : m_Inner(InMode, InSubTask) {}
-
- virtual void UpdateState(const State& NewState, bool DoLinebreak)
- {
- zen::ProgressBar::State State = {.Task = NewState.Task,
- .Details = NewState.Details,
- .TotalCount = NewState.TotalCount,
- .RemainingCount = NewState.RemainingCount,
- .Status = ConvertStatus(NewState.Status)};
- m_Inner.UpdateState(State, DoLinebreak);
- }
- virtual void Finish() { m_Inner.Finish(); }
-
-private:
- zen::ProgressBar::State::EStatus ConvertStatus(State::EStatus Status)
- {
- switch (Status)
- {
- case State::EStatus::Running:
- return zen::ProgressBar::State::EStatus::Running;
- case State::EStatus::Aborted:
- return zen::ProgressBar::State::EStatus::Aborted;
- case State::EStatus::Paused:
- return zen::ProgressBar::State::EStatus::Paused;
- default:
- return (zen::ProgressBar::State::EStatus)Status;
- }
- }
- zen::ProgressBar m_Inner;
-};
-
-class ConsoleOpLogOutput : public OperationLogOutput
+class ConsoleOpLogOutput : public ProgressBase
{
public:
ConsoleOpLogOutput(zen::ProgressBar::Mode InMode) : m_Mode(InMode) {}
- virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override
- {
- logging::EmitConsoleLogMessage(Point, Args);
- }
virtual void SetLogOperationName(std::string_view Name) override { zen::ProgressBar::SetLogOperationName(m_Mode, Name); }
virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override
{
zen::ProgressBar::SetLogOperationProgress(m_Mode, StepIndex, StepCount);
}
- virtual uint32_t GetProgressUpdateDelayMS() override { return GetUpdateDelayMS(m_Mode); }
+ virtual uint32_t GetProgressUpdateDelayMS() const override { return GetUpdateDelayMS(m_Mode); }
- virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override { return new ConsoleOpLogProgressBar(m_Mode, InSubTask); }
+ virtual std::unique_ptr<ProgressBase::ProgressBar> CreateProgressBar(std::string_view InSubTask) override
+ {
+ return std::make_unique<zen::ProgressBar>(m_Mode, InSubTask);
+ }
private:
zen::ProgressBar::Mode m_Mode;
};
-OperationLogOutput*
-CreateConsoleLogOutput(ProgressBar::Mode InMode)
+ProgressBase*
+CreateConsoleProgress(ProgressBar::Mode InMode)
{
return new ConsoleOpLogOutput(InMode);
}
diff --git a/src/zen/progressbar.h b/src/zen/progressbar.h
index b54c009e1..26bb9b9c4 100644
--- a/src/zen/progressbar.h
+++ b/src/zen/progressbar.h
@@ -4,46 +4,13 @@
#include <zencore/timer.h>
#include <zencore/zencore.h>
-
-#include <string>
+#include <zenutil/progress.h>
namespace zen {
-class OperationLogOutput;
-
-class ProgressBar
+class ProgressBar : public ProgressBase::ProgressBar
{
public:
- struct State
- {
- bool operator==(const State&) const = default;
- std::string Task;
- std::string Details;
- uint64_t TotalCount = 0;
- uint64_t RemainingCount = 0;
- uint64_t OptionalElapsedTime = (uint64_t)-1;
- enum class EStatus
- {
- Running,
- Aborted,
- Paused
- };
- EStatus Status = EStatus::Running;
-
- static EStatus CalculateStatus(bool IsAborted, bool IsPaused)
- {
- if (IsAborted)
- {
- return EStatus::Aborted;
- }
- if (IsPaused)
- {
- return EStatus::Paused;
- }
- return EStatus::Running;
- }
- };
-
enum class Mode
{
Plain,
@@ -60,24 +27,30 @@ public:
explicit ProgressBar(Mode InMode, std::string_view InSubTask);
~ProgressBar();
- void UpdateState(const State& NewState, bool DoLinebreak);
+ void UpdateState(const State& NewState, bool DoLinebreak) override;
void ForceLinebreak();
- void Finish();
+ void Finish() override;
bool IsSameTask(std::string_view Task) const;
bool HasActiveTask() const;
private:
+ void SetupScrollRegion();
+ void TeardownScrollRegion();
+ void RenderStatusLine(std::string_view Line);
+
const Mode m_Mode;
Stopwatch m_SW;
uint64_t m_LastUpdateMS;
uint64_t m_PausedMS;
State m_State;
const std::string m_SubTask;
- size_t m_LastOutputLength = 0;
+ size_t m_LastOutputLength = 0;
+ bool m_ScrollRegionActive = false;
+ uint32_t m_ScrollRegionRows = 0;
};
uint32_t GetUpdateDelayMS(ProgressBar::Mode InMode);
-OperationLogOutput* CreateConsoleLogOutput(ProgressBar::Mode InMode);
+ProgressBase* CreateConsoleProgress(ProgressBar::Mode InMode);
} // namespace zen
diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp
index 3277eb856..0229db4a8 100644
--- a/src/zen/zen.cpp
+++ b/src/zen/zen.cpp
@@ -9,6 +9,7 @@
#include "cmds/bench_cmd.h"
#include "cmds/builds_cmd.h"
#include "cmds/cache_cmd.h"
+#include "cmds/compute_cmd.h"
#include "cmds/copy_cmd.h"
#include "cmds/dedup_cmd.h"
#include "cmds/exec_cmd.h"
@@ -368,7 +369,7 @@ ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char**
}
}
- // Parse subcommand args permissively — unrecognised options are collected
+ // Parse subcommand args permissively - unrecognised options are collected
// and forwarded to the parent parser so that parent options (e.g. --path)
// can appear after the subcommand name on the command line.
std::vector<std::string> SubUnmatched;
@@ -588,7 +589,8 @@ main(int argc, char** argv)
DropCommand DropCmd;
DropProjectCommand ProjectDropCmd;
#if ZEN_WITH_COMPUTE_SERVICES
- ExecCommand ExecCmd;
+ ComputeCommand ComputeCmd;
+ ExecCommand ExecCmd;
#endif // ZEN_WITH_COMPUTE_SERVICES
ExportOplogCommand ExportOplogCmd;
FlushCommand FlushCmd;
@@ -649,6 +651,7 @@ main(int argc, char** argv)
{DownCommand::Name, &DownCmd, DownCommand::Description},
{DropCommand::Name, &DropCmd, DropCommand::Description},
#if ZEN_WITH_COMPUTE_SERVICES
+ {ComputeCommand::Name, &ComputeCmd, ComputeCommand::Description},
{ExecCommand::Name, &ExecCmd, ExecCommand::Description},
#endif
{GcStatusCommand::Name, &GcStatusCmd, GcStatusCommand::Description},
diff --git a/src/zenbase/include/zenbase/atomic.h b/src/zenbase/include/zenbase/atomic.h
deleted file mode 100644
index 4ad7962cf..000000000
--- a/src/zenbase/include/zenbase/atomic.h
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zenbase/zenbase.h>
-
-#if ZEN_COMPILER_MSC
-# include <intrin.h>
-#else
-# include <atomic>
-#endif
-
-#include <cinttypes>
-
-namespace zen {
-
-inline uint32_t
-AtomicIncrement(volatile uint32_t& value)
-{
-#if ZEN_COMPILER_MSC
- return _InterlockedIncrement((long volatile*)&value);
-#else
- return ((std::atomic<uint32_t>*)(&value))->fetch_add(1, std::memory_order_seq_cst) + 1;
-#endif
-}
-inline uint32_t
-AtomicDecrement(volatile uint32_t& value)
-{
-#if ZEN_COMPILER_MSC
- return _InterlockedDecrement((long volatile*)&value);
-#else
- return ((std::atomic<uint32_t>*)(&value))->fetch_sub(1, std::memory_order_seq_cst) - 1;
-#endif
-}
-
-inline uint64_t
-AtomicIncrement(volatile uint64_t& value)
-{
-#if ZEN_COMPILER_MSC
- return _InterlockedIncrement64((__int64 volatile*)&value);
-#else
- return ((std::atomic<uint64_t>*)(&value))->fetch_add(1, std::memory_order_seq_cst) + 1;
-#endif
-}
-inline uint64_t
-AtomicDecrement(volatile uint64_t& value)
-{
-#if ZEN_COMPILER_MSC
- return _InterlockedDecrement64((__int64 volatile*)&value);
-#else
- return ((std::atomic<uint64_t>*)(&value))->fetch_sub(1, std::memory_order_seq_cst) - 1;
-#endif
-}
-
-inline uint32_t
-AtomicAdd(volatile uint32_t& value, uint32_t amount)
-{
-#if ZEN_COMPILER_MSC
- return _InterlockedExchangeAdd((long volatile*)&value, amount);
-#else
- return ((std::atomic<uint32_t>*)(&value))->fetch_add(amount, std::memory_order_seq_cst);
-#endif
-}
-inline uint64_t
-AtomicAdd(volatile uint64_t& value, uint64_t amount)
-{
-#if ZEN_COMPILER_MSC
- return _InterlockedExchangeAdd64((__int64 volatile*)&value, amount);
-#else
- return ((std::atomic<uint64_t>*)(&value))->fetch_add(amount, std::memory_order_seq_cst);
-#endif
-}
-
-} // namespace zen
diff --git a/src/zenbase/include/zenbase/refcount.h b/src/zenbase/include/zenbase/refcount.h
index 08bc6ae54..0da78ad91 100644
--- a/src/zenbase/include/zenbase/refcount.h
+++ b/src/zenbase/include/zenbase/refcount.h
@@ -1,9 +1,9 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#pragma once
-#include <zenbase/atomic.h>
#include <zenbase/concepts.h>
+#include <atomic>
#include <compare>
namespace zen {
@@ -13,6 +13,13 @@ namespace zen {
*
* When the reference count reaches zero, the object deletes itself. This class relies
* on having a virtual destructor to ensure proper cleanup of derived classes.
+ *
+ * Release() is marked noexcept. Derived destructors are expected not to throw - if a
+ * derived destructor does throw, the exception crosses the noexcept boundary and the
+ * program terminates via std::terminate. This matches C++'s default destructor
+ * behaviour (destructors are implicitly noexcept) but is spelled out explicitly here
+ * because the delete happens inside Release() rather than at the call site, so the
+ * terminate point is not visually obvious.
*/
class RefCounted
{
@@ -20,10 +27,15 @@ public:
RefCounted() = default;
virtual ~RefCounted() = default;
- inline uint32_t AddRef() const noexcept { return AtomicIncrement(const_cast<RefCounted*>(this)->m_RefCount); }
- inline uint32_t Release() const
+ // AddRef uses relaxed ordering: a thread can only add a reference to an object it
+ // already has a reference to, so there is nothing that needs to synchronize here.
+ // Release uses acq_rel so that (a) all prior modifications of the object made on
+ // other threads happen-before the destructor, and (b) the ref-count decrement is
+ // visible to any thread that observes the new count.
+ inline uint32_t AddRef() const noexcept { return m_RefCount.fetch_add(1, std::memory_order_relaxed) + 1; }
+ inline uint32_t Release() const noexcept
{
- const uint32_t RefCount = AtomicDecrement(const_cast<RefCounted*>(this)->m_RefCount);
+ const uint32_t RefCount = m_RefCount.fetch_sub(1, std::memory_order_acq_rel) - 1;
if (RefCount == 0)
{
delete this;
@@ -39,10 +51,12 @@ public:
RefCounted& operator=(RefCounted&&) = delete;
protected:
- inline uint32_t RefCount() const { return m_RefCount; }
+ // Diagnostic accessor: relaxed load is fine because the returned value is already
+ // stale the moment it is observed, so no extra ordering would make it more reliable.
+ inline uint32_t RefCount() const noexcept { return m_RefCount.load(std::memory_order_relaxed); }
private:
- uint32_t m_RefCount = 0;
+ mutable std::atomic<uint32_t> m_RefCount = 0;
};
/**
@@ -52,6 +66,10 @@ private:
* It has no virtual destructor, so it's important that you either don't derive from it further,
* or ensure that the derived class has a virtual destructor.
*
+ * As with RefCounted, Release() is noexcept and the derived destructor must not throw.
+ * A throwing destructor will cause std::terminate to be called when the refcount hits
+ * zero.
+ *
* This class is useful when you want to avoid adding a vtable to a class just to implement
* reference counting.
*/
@@ -60,15 +78,17 @@ template<typename T>
class TRefCounted
{
public:
- TRefCounted() = default;
- ~TRefCounted() = default;
+ TRefCounted() = default;
- inline uint32_t AddRef() const noexcept { return AtomicIncrement(const_cast<TRefCounted<T>*>(this)->m_RefCount); }
- inline uint32_t Release() const
+ // See RefCounted::AddRef/Release for ordering rationale.
+ inline uint32_t AddRef() const noexcept { return m_RefCount.fetch_add(1, std::memory_order_relaxed) + 1; }
+ inline uint32_t Release() const noexcept
{
- const uint32_t RefCount = AtomicDecrement(const_cast<TRefCounted<T>*>(this)->m_RefCount);
+ const uint32_t RefCount = m_RefCount.fetch_sub(1, std::memory_order_acq_rel) - 1;
if (RefCount == 0)
{
+ // DeleteThis may be overridden as a non-const member in derived types,
+ // so we cast away const to support both signatures.
const_cast<T*>(static_cast<const T*>(this))->DeleteThis();
}
return RefCount;
@@ -82,92 +102,21 @@ public:
TRefCounted& operator=(TRefCounted&&) = delete;
protected:
- inline uint32_t RefCount() const { return m_RefCount; }
-
- void DeleteThis() const noexcept { delete static_cast<const T*>(this); }
-
-private:
- uint32_t m_RefCount = 0;
-};
-
-/**
- * Smart pointer for classes derived from RefCounted
- */
-
-template<class T>
-class RefPtr
-{
-public:
- inline RefPtr() = default;
- inline RefPtr(const RefPtr& Rhs) : m_Ref(Rhs.m_Ref) { m_Ref && m_Ref->AddRef(); }
- inline RefPtr(T* Ptr) : m_Ref(Ptr) { m_Ref && m_Ref->AddRef(); }
- inline ~RefPtr() { m_Ref && m_Ref->Release(); }
-
- [[nodiscard]] inline bool IsNull() const { return m_Ref == nullptr; }
- inline explicit operator bool() const { return m_Ref != nullptr; }
- inline operator T*() const { return m_Ref; }
- inline T* operator->() const { return m_Ref; }
+ // Non-virtual: destruction goes through DeleteThis(), which deletes the derived type.
+ // Protected so that external callers cannot `delete` through a TRefCounted<T>* and slice.
+ ~TRefCounted() = default;
- inline std::strong_ordering operator<=>(const RefPtr& Rhs) const = default;
+ // Diagnostic accessor: see RefCounted::RefCount for ordering rationale.
+ inline uint32_t RefCount() const noexcept { return m_RefCount.load(std::memory_order_relaxed); }
- inline RefPtr& operator=(T* Rhs)
- {
- Rhs && Rhs->AddRef();
- m_Ref && m_Ref->Release();
- m_Ref = Rhs;
- return *this;
- }
- inline RefPtr& operator=(const RefPtr& Rhs)
- {
- if (&Rhs != this)
- {
- Rhs && Rhs->AddRef();
- m_Ref && m_Ref->Release();
- m_Ref = Rhs.m_Ref;
- }
- return *this;
- }
- inline RefPtr& operator=(RefPtr&& Rhs) noexcept
- {
- if (&Rhs != this)
- {
- m_Ref && m_Ref->Release();
- m_Ref = Rhs.m_Ref;
- Rhs.m_Ref = nullptr;
- }
- return *this;
- }
- template<typename OtherType>
- inline RefPtr& operator=(RefPtr<OtherType>&& Rhs) noexcept
- {
- if ((RefPtr*)&Rhs != this)
- {
- m_Ref && m_Ref->Release();
- m_Ref = Rhs.m_Ref;
- Rhs.m_Ref = nullptr;
- }
- return *this;
- }
- inline RefPtr(RefPtr&& Rhs) noexcept : m_Ref(Rhs.m_Ref) { Rhs.m_Ref = nullptr; }
- template<typename OtherType>
- explicit inline RefPtr(RefPtr<OtherType>&& Rhs) noexcept : m_Ref(Rhs.m_Ref)
- {
- Rhs.m_Ref = nullptr;
- }
-
- inline void Swap(RefPtr& Rhs) noexcept { std::swap(m_Ref, Rhs.m_Ref); }
+ void DeleteThis() const noexcept { delete static_cast<const T*>(this); }
private:
- T* m_Ref = nullptr;
- template<typename U>
- friend class RefPtr;
+ mutable std::atomic<uint32_t> m_RefCount = 0;
};
/**
- * Smart pointer for classes derived from RefCounted
- *
- * This variant does not decay to a raw pointer
- *
+ * Smart pointer for classes derived from RefCounted (or TRefCounted)
*/
template<class T>
@@ -177,12 +126,16 @@ public:
inline Ref() = default;
inline Ref(Ref&& Rhs) noexcept : m_Ref(Rhs.m_Ref) { Rhs.m_Ref = nullptr; }
inline Ref(const Ref& Rhs) noexcept : m_Ref(Rhs.m_Ref) { m_Ref && m_Ref->AddRef(); }
- inline explicit Ref(T* Ptr) : m_Ref(Ptr) { m_Ref && m_Ref->AddRef(); }
- inline ~Ref() { m_Ref && m_Ref->Release(); }
+ inline explicit Ref(T* Ptr) noexcept : m_Ref(Ptr) { m_Ref && m_Ref->AddRef(); }
+ inline ~Ref() noexcept { m_Ref && m_Ref->Release(); }
template<typename DerivedType>
requires DerivedFrom<DerivedType, T>
- inline Ref(const Ref<DerivedType>& Rhs) : Ref(Rhs.m_Ref) {}
+ inline Ref(const Ref<DerivedType>& Rhs) noexcept : Ref(Rhs.m_Ref) {}
+
+ template<typename DerivedType>
+ requires DerivedFrom<DerivedType, T>
+ inline Ref(Ref<DerivedType>&& Rhs) noexcept : m_Ref(Rhs.m_Ref) { Rhs.m_Ref = nullptr; }
[[nodiscard]] inline bool IsNull() const { return m_Ref == nullptr; }
inline explicit operator bool() const { return m_Ref != nullptr; }
@@ -192,14 +145,14 @@ public:
inline std::strong_ordering operator<=>(const Ref& Rhs) const = default;
- inline Ref& operator=(T* Rhs)
+ inline Ref& operator=(T* Rhs) noexcept
{
Rhs && Rhs->AddRef();
m_Ref && m_Ref->Release();
m_Ref = Rhs;
return *this;
}
- inline Ref& operator=(const Ref& Rhs)
+ inline Ref& operator=(const Ref& Rhs) noexcept
{
if (&Rhs != this)
{
@@ -219,6 +172,20 @@ public:
}
return *this;
}
+ template<typename DerivedType>
+ requires DerivedFrom<DerivedType, T>
+ inline Ref& operator=(Ref<DerivedType>&& Rhs) noexcept
+ {
+ if ((Ref*)&Rhs != this)
+ {
+ m_Ref && m_Ref->Release();
+ m_Ref = Rhs.m_Ref;
+ Rhs.m_Ref = nullptr;
+ }
+ return *this;
+ }
+
+ inline void Swap(Ref& Rhs) noexcept { std::swap(m_Ref, Rhs.m_Ref); }
private:
T* m_Ref = nullptr;
diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md
index a1a39fc3c..bb574edc2 100644
--- a/src/zencompute/CLAUDE.md
+++ b/src/zencompute/CLAUDE.md
@@ -141,7 +141,7 @@ Actions that fail or are abandoned can be automatically retried or manually resc
**Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure.
-**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit.
+**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Setting `max_retries=0` disables automatic retry entirely; omitting the field (or setting it to a negative value) uses the default of 3. Both automatic and manual paths respect this limit.
**Retraction (API path):** `RetractAction(Lsn)` pulls a Pending/Submitting/Running action back for rescheduling on a different runner. The action transitions to Retracted, then `ResetActionStateToPending()` is called *without* incrementing `RetryCount`. Retraction is idempotent.
@@ -156,7 +156,7 @@ Queues group actions from a single client session. A `QueueEntry` (internal) tra
- `ActiveLsns` — for cancellation lookup (under `m_Lock`)
- `FinishedLsns` — moved here when actions complete
- `IdleSince` — used for 15-minute automatic expiry
-- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit
+- `Config` — CbObject set at creation; supports `max_retries` (int, default 3) to override the default retry limit. `0` = no retries, negative or absent = use default
**Queue state machine (`QueueState` enum):**
```
@@ -216,11 +216,14 @@ Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `Han
## Concurrency Model
-**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks:
-1. `m_ResultsLock`
-2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock")
-3. `m_PendingLock`
-4. `m_QueueLock`
+**Locking discipline:** The three action maps (`m_PendingActions`, `m_RunningMap`, `m_ResultsMap`) are guarded by a single `m_ActionMapLock`. This eliminates lock-ordering concerns between maps and prevents actions from being temporarily absent from all maps during state transitions. Runner-level `m_RunningLock` in `LocalProcessRunner` / `RemoteHttpRunner` is a separate lock on a different class — unrelated to the session-level action map lock.
+
+**Lock ordering:** When acquiring multiple session-level locks, always acquire in this order to avoid deadlocks:
+1. `m_ActionMapLock` (session action maps)
+2. `QueueEntry::m_Lock` (per-queue state)
+3. `m_ActionHistoryLock` (action history ring)
+
+Never acquire an earlier lock while holding a later one (e.g. never acquire `m_ActionMapLock` while holding `QueueEntry::m_Lock`).
**Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`.
diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp
index eb4c05f9f..f1df18e8e 100644
--- a/src/zencompute/cloudmetadata.cpp
+++ b/src/zencompute/cloudmetadata.cpp
@@ -183,7 +183,7 @@ CloudMetadata::TryDetectAWS()
m_Info.AvailabilityZone = std::string(AzResponse.AsText());
}
- // "spot" vs "on-demand" — determines whether the instance can be
+ // "spot" vs "on-demand" - determines whether the instance can be
// reclaimed by AWS with a 2-minute warning
HttpClient::Response LifecycleResponse = ImdsClient.Get("/latest/meta-data/instance-life-cycle", AuthHeaders);
if (LifecycleResponse.IsSuccess())
@@ -273,7 +273,7 @@ CloudMetadata::TryDetectAzure()
std::string Priority = Compute["priority"].string_value();
m_Info.IsSpot = (Priority == "Spot");
- // Check if part of a VMSS (Virtual Machine Scale Set) — indicates autoscaling
+ // Check if part of a VMSS (Virtual Machine Scale Set) - indicates autoscaling
std::string VmssName = Compute["vmScaleSetName"].string_value();
m_Info.IsAutoscaling = !VmssName.empty();
@@ -609,7 +609,7 @@ namespace zen::compute {
TEST_SUITE_BEGIN("compute.cloudmetadata");
// ---------------------------------------------------------------------------
-// Test helper — spins up a local ASIO HTTP server hosting a MockImdsService
+// Test helper - spins up a local ASIO HTTP server hosting a MockImdsService
// ---------------------------------------------------------------------------
struct TestImdsServer
@@ -974,7 +974,7 @@ TEST_CASE("cloudmetadata.sentinel_files")
SUBCASE("only failed providers get sentinels")
{
- // Switch to AWS — Azure and GCP never probed, so no sentinels for them
+ // Switch to AWS - Azure and GCP never probed, so no sentinels for them
Imds.Mock.ActiveProvider = CloudProvider::AWS;
auto Cloud = Imds.CreateCloud();
diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp
index 92901de64..7f354a51c 100644
--- a/src/zencompute/computeservice.cpp
+++ b/src/zencompute/computeservice.cpp
@@ -8,6 +8,8 @@
# include "recording/actionrecorder.h"
# include "runners/localrunner.h"
# include "runners/remotehttprunner.h"
+# include "runners/managedrunner.h"
+# include "pathvalidation.h"
# if ZEN_PLATFORM_LINUX
# include "runners/linuxrunner.h"
# elif ZEN_PLATFORM_WINDOWS
@@ -119,6 +121,8 @@ struct ComputeServiceSession::Impl
, m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst))
, m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst))
{
+ m_RemoteRunnerGroup.SetWorkerPool(&m_RemoteSubmitPool);
+
// Create a non-expiring, non-deletable implicit queue for legacy endpoints
auto Result = CreateQueue("implicit"sv, {}, {});
m_ImplicitQueueId = Result.QueueId;
@@ -195,13 +199,9 @@ struct ComputeServiceSession::Impl
std::atomic<IComputeCompletionObserver*> m_CompletionObserver{nullptr};
- RwLock m_PendingLock;
- std::map<int, Ref<RunnerAction>> m_PendingActions;
-
- RwLock m_RunningLock;
+ RwLock m_ActionMapLock; // Guards m_PendingActions, m_RunningMap, m_ResultsMap
+ std::map<int, Ref<RunnerAction>> m_PendingActions;
std::unordered_map<int, Ref<RunnerAction>> m_RunningMap;
-
- RwLock m_ResultsLock;
std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap;
metrics::Meter m_ResultRate;
std::atomic<uint64_t> m_RetiredCount{0};
@@ -242,8 +242,9 @@ struct ComputeServiceSession::Impl
// Recording
- void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath);
- void StopRecording();
+ bool StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath);
+ bool StopRecording();
+ bool IsRecording() const;
std::unique_ptr<ActionRecorder> m_Recorder;
@@ -343,9 +344,12 @@ struct ComputeServiceSession::Impl
ActionCounts GetActionCounts()
{
ActionCounts Counts;
- Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
- Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
- Counts.Completed = (int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load();
+ m_ActionMapLock.WithSharedLock([&] {
+ Counts.Pending = (int)m_PendingActions.size();
+ Counts.Running = (int)m_RunningMap.size();
+ Counts.Completed = (int)m_ResultsMap.size();
+ });
+ Counts.Completed += (int)m_RetiredCount.load();
Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] {
size_t Count = 0;
for (const auto& [Id, Queue] : m_Queues)
@@ -364,8 +368,10 @@ struct ComputeServiceSession::Impl
{
Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed));
m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); });
- m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); });
- m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); });
+ m_ActionMapLock.WithSharedLock([&] {
+ Cbo << "actions_complete"sv << m_ResultsMap.size();
+ Cbo << "actions_pending"sv << m_PendingActions.size();
+ });
Cbo << "actions_submitted"sv << GetSubmittedActionCount();
EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo);
EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo);
@@ -443,34 +449,24 @@ ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState)
return true;
}
- // CAS failed, Current was updated — retry with the new value
+ // CAS failed, Current was updated - retry with the new value
}
}
void
ComputeServiceSession::Impl::AbandonAllActions()
{
- // Collect all pending actions and mark them as Abandoned
+ // Collect all pending and running actions under a single lock scope
std::vector<Ref<RunnerAction>> PendingToAbandon;
+ std::vector<Ref<RunnerAction>> RunningToAbandon;
- m_PendingLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
PendingToAbandon.reserve(m_PendingActions.size());
for (auto& [Lsn, Action] : m_PendingActions)
{
PendingToAbandon.push_back(Action);
}
- });
-
- for (auto& Action : PendingToAbandon)
- {
- Action->SetActionState(RunnerAction::State::Abandoned);
- }
-
- // Collect all running actions and mark them as Abandoned, then
- // best-effort cancel via the local runner group
- std::vector<Ref<RunnerAction>> RunningToAbandon;
- m_RunningLock.WithSharedLock([&] {
RunningToAbandon.reserve(m_RunningMap.size());
for (auto& [Lsn, Action] : m_RunningMap)
{
@@ -478,6 +474,11 @@ ComputeServiceSession::Impl::AbandonAllActions()
}
});
+ for (auto& Action : PendingToAbandon)
+ {
+ Action->SetActionState(RunnerAction::State::Abandoned);
+ }
+
for (auto& Action : RunningToAbandon)
{
Action->SetActionState(RunnerAction::State::Abandoned);
@@ -617,6 +618,7 @@ ComputeServiceSession::Impl::UpdateCoordinatorState()
m_KnownWorkerUris.insert(UriStr);
auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool);
+ NewRunner->SetRemoteHostname(Hostname);
SyncWorkersToRunner(*NewRunner);
m_RemoteRunnerGroup.AddRunner(NewRunner);
}
@@ -718,31 +720,51 @@ ComputeServiceSession::Impl::ShutdownRunners()
m_RemoteRunnerGroup.Shutdown();
}
-void
+bool
ComputeServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath)
{
+ if (m_Recorder)
+ {
+ ZEN_WARN("recording is already active");
+ return false;
+ }
+
ZEN_INFO("starting recording to '{}'", RecordingPath);
m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath);
ZEN_INFO("started recording to '{}'", RecordingPath);
+ return true;
}
-void
+bool
ComputeServiceSession::Impl::StopRecording()
{
+ if (!m_Recorder)
+ {
+ ZEN_WARN("no recording is active");
+ return false;
+ }
+
ZEN_INFO("stopping recording");
m_Recorder = nullptr;
ZEN_INFO("stopped recording");
+ return true;
+}
+
+bool
+ComputeServiceSession::Impl::IsRecording() const
+{
+ return m_Recorder != nullptr;
}
std::vector<ComputeServiceSession::RunningActionInfo>
ComputeServiceSession::Impl::GetRunningActions()
{
std::vector<ComputeServiceSession::RunningActionInfo> Result;
- m_RunningLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
Result.reserve(m_RunningMap.size());
for (const auto& [Lsn, Action] : m_RunningMap)
{
@@ -810,6 +832,11 @@ void
ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker)
{
ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker");
+
+ // Validate all paths in the worker description upfront, before the worker is
+ // distributed to runners. This rejects malicious packages early at ingestion time.
+ ValidateWorkerDescriptionPaths(Worker.GetObject());
+
RwLock::ExclusiveLockScope _(m_WorkerLock);
const IoHash& WorkerId = Worker.GetObject().GetHash();
@@ -994,10 +1021,15 @@ ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worke
Pending->ActionObj = ActionObj;
Pending->Priority = RequestPriority;
- // For now simply put action into pending state, so we can do batch scheduling
+ // Insert into the pending map immediately so the action is visible to
+ // FindActionResult/GetActionResult right away. SetActionState will call
+ // PostUpdate which adds the action to m_UpdatedActions and signals the
+ // scheduler, but the scheduler's HandleActionUpdates inserts with
+ // std::map::insert which is a no-op for existing keys.
ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn);
+ m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Pending}); });
Pending->SetActionState(RunnerAction::State::Pending);
if (m_Recorder)
@@ -1043,11 +1075,7 @@ ComputeServiceSession::Impl::GetSubmittedActionCount()
HttpResponseCode
ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage)
{
- // This lock is held for the duration of the function since we need to
- // be sure that the action doesn't change state while we are checking the
- // different data structures
-
- RwLock::ExclusiveLockScope _(m_ResultsLock);
+ RwLock::ExclusiveLockScope _(m_ActionMapLock);
if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end())
{
@@ -1058,25 +1086,14 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult
return HttpResponseCode::OK;
}
+ if (m_PendingActions.find(ActionLsn) != m_PendingActions.end())
{
- RwLock::SharedLockScope __(m_PendingLock);
-
- if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end())
- {
- return HttpResponseCode::Accepted;
- }
+ return HttpResponseCode::Accepted;
}
- // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must
- // always be taken after m_ResultsLock if both are needed
-
+ if (m_RunningMap.find(ActionLsn) != m_RunningMap.end())
{
- RwLock::SharedLockScope __(m_RunningLock);
-
- if (m_RunningMap.find(ActionLsn) != m_RunningMap.end())
- {
- return HttpResponseCode::Accepted;
- }
+ return HttpResponseCode::Accepted;
}
return HttpResponseCode::NotFound;
@@ -1085,11 +1102,7 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult
HttpResponseCode
ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage)
{
- // This lock is held for the duration of the function since we need to
- // be sure that the action doesn't change state while we are checking the
- // different data structures
-
- RwLock::ExclusiveLockScope _(m_ResultsLock);
+ RwLock::ExclusiveLockScope _(m_ActionMapLock);
for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It)
{
@@ -1103,30 +1116,19 @@ ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage&
}
}
+ for (const auto& [K, Pending] : m_PendingActions)
{
- RwLock::SharedLockScope __(m_PendingLock);
-
- for (const auto& [K, Pending] : m_PendingActions)
+ if (Pending->ActionId == ActionId)
{
- if (Pending->ActionId == ActionId)
- {
- return HttpResponseCode::Accepted;
- }
+ return HttpResponseCode::Accepted;
}
}
- // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must
- // always be taken after m_ResultsLock if both are needed
-
+ for (const auto& [K, v] : m_RunningMap)
{
- RwLock::SharedLockScope __(m_RunningLock);
-
- for (const auto& [K, v] : m_RunningMap)
+ if (v->ActionId == ActionId)
{
- if (v->ActionId == ActionId)
- {
- return HttpResponseCode::Accepted;
- }
+ return HttpResponseCode::Accepted;
}
}
@@ -1144,12 +1146,16 @@ ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo)
{
Cbo.BeginArray("completed");
- m_ResultsLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
for (auto& [Lsn, Action] : m_ResultsMap)
{
Cbo.BeginObject();
Cbo << "lsn"sv << Lsn;
Cbo << "state"sv << RunnerAction::ToString(Action->ActionState());
+ if (!Action->FailureReason.empty())
+ {
+ Cbo << "reason"sv << Action->FailureReason;
+ }
Cbo.EndObject();
}
});
@@ -1275,20 +1281,14 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId)
std::vector<Ref<RunnerAction>> PendingActionsToCancel;
std::vector<int> RunningLsnsToCancel;
- m_PendingLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
for (int Lsn : LsnsToCancel)
{
if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end())
{
PendingActionsToCancel.push_back(It->second);
}
- }
- });
-
- m_RunningLock.WithSharedLock([&] {
- for (int Lsn : LsnsToCancel)
- {
- if (m_RunningMap.find(Lsn) != m_RunningMap.end())
+ else if (m_RunningMap.find(Lsn) != m_RunningMap.end())
{
RunningLsnsToCancel.push_back(Lsn);
}
@@ -1307,7 +1307,7 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId)
// transition from the runner is blocked (Cancelled > Failed in the enum).
for (int Lsn : RunningLsnsToCancel)
{
- m_RunningLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end())
{
It->second->SetActionState(RunnerAction::State::Cancelled);
@@ -1444,8 +1444,8 @@ ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo)
if (Queue)
{
- Queue->m_Lock.WithSharedLock([&] {
- m_ResultsLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
+ Queue->m_Lock.WithSharedLock([&] {
for (int Lsn : Queue->FinishedLsns)
{
if (m_ResultsMap.contains(Lsn))
@@ -1475,15 +1475,19 @@ ComputeServiceSession::Impl::NotifyQueueActionComplete(int QueueId, int Lsn, Run
return;
}
+ bool WasActive = false;
Queue->m_Lock.WithExclusiveLock([&] {
- Queue->ActiveLsns.erase(Lsn);
+ WasActive = Queue->ActiveLsns.erase(Lsn) > 0;
Queue->FinishedLsns.insert(Lsn);
});
- const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed);
- if (PreviousActive == 1)
+ if (WasActive)
{
- Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed);
+ const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed);
+ if (PreviousActive == 1)
+ {
+ Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed);
+ }
}
switch (ActionState)
@@ -1541,26 +1545,32 @@ ComputeServiceSession::Impl::SchedulePendingActions()
{
ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions");
int ScheduledCount = 0;
- size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
- size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
- size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); });
+ size_t RunningCount = 0;
+ size_t PendingCount = 0;
+ size_t ResultCount = 0;
+
+ m_ActionMapLock.WithSharedLock([&] {
+ RunningCount = m_RunningMap.size();
+ PendingCount = m_PendingActions.size();
+ ResultCount = m_ResultsMap.size();
+ });
static Stopwatch DumpRunningTimer;
auto _ = MakeGuard([&] {
- ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results",
- ScheduledCount,
- RunningCount,
- m_RetiredCount.load(),
- PendingCount,
- ResultCount);
+ ZEN_DEBUG("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results",
+ ScheduledCount,
+ RunningCount,
+ m_RetiredCount.load(),
+ PendingCount,
+ ResultCount);
if (DumpRunningTimer.GetElapsedTimeMs() > 30000)
{
DumpRunningTimer.Reset();
std::set<int> RunningList;
- m_RunningLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
for (auto& [K, V] : m_RunningMap)
{
RunningList.insert(K);
@@ -1602,13 +1612,13 @@ ComputeServiceSession::Impl::SchedulePendingActions()
// Also note that the m_PendingActions list is not maintained
// here, that's done periodically in SchedulePendingActions()
- m_PendingLock.WithExclusiveLock([&] {
- if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused)
- {
- return;
- }
+ // Extract pending actions under a shared lock - we only need to read
+ // the map and take Ref copies. ActionState() is atomic so this is safe.
+ // Sorting and capacity trimming happen outside the lock to avoid
+ // blocking HTTP handlers on O(N log N) work with large pending queues.
- if (m_PendingActions.empty())
+ m_ActionMapLock.WithSharedLock([&] {
+ if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused)
{
return;
}
@@ -1628,6 +1638,7 @@ ComputeServiceSession::Impl::SchedulePendingActions()
case RunnerAction::State::Completed:
case RunnerAction::State::Failed:
case RunnerAction::State::Abandoned:
+ case RunnerAction::State::Rejected:
case RunnerAction::State::Cancelled:
break;
@@ -1638,30 +1649,30 @@ ComputeServiceSession::Impl::SchedulePendingActions()
}
}
- // Sort by priority descending, then by LSN ascending (FIFO within same priority)
- std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) {
- if (A->Priority != B->Priority)
- {
- return A->Priority > B->Priority;
- }
- return A->ActionLsn < B->ActionLsn;
- });
+ PendingCount = m_PendingActions.size();
+ });
- if (ActionsToSchedule.size() > Capacity)
+ // Sort by priority descending, then by LSN ascending (FIFO within same priority)
+ std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) {
+ if (A->Priority != B->Priority)
{
- ActionsToSchedule.resize(Capacity);
+ return A->Priority > B->Priority;
}
-
- PendingCount = m_PendingActions.size();
+ return A->ActionLsn < B->ActionLsn;
});
+ if (ActionsToSchedule.size() > Capacity)
+ {
+ ActionsToSchedule.resize(Capacity);
+ }
+
if (ActionsToSchedule.empty())
{
_.Dismiss();
return;
}
- ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size());
+ ZEN_DEBUG("attempting schedule of {} pending actions", ActionsToSchedule.size());
Stopwatch SubmitTimer;
std::vector<SubmitResult> SubmitResults = SubmitActions(ActionsToSchedule);
@@ -1681,10 +1692,10 @@ ComputeServiceSession::Impl::SchedulePendingActions()
}
}
- ZEN_INFO("scheduled {} pending actions in {} ({} rejected)",
- ScheduledActionCount,
- NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
- NotAcceptedCount);
+ ZEN_DEBUG("scheduled {} pending actions in {} ({} rejected)",
+ ScheduledActionCount,
+ NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
+ NotAcceptedCount);
ScheduledCount += ScheduledActionCount;
PendingCount -= ScheduledActionCount;
@@ -1701,7 +1712,7 @@ ComputeServiceSession::Impl::SchedulerThreadFunction()
{
int TimeoutMs = 500;
- auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
+ auto PendingCount = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.size(); });
if (PendingCount)
{
@@ -1720,22 +1731,22 @@ ComputeServiceSession::Impl::SchedulerThreadFunction()
m_SchedulingThreadEvent.Reset();
}
- ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}",
- m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }),
- PendingCount,
- m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }),
- m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }),
- TimeoutMs);
+ m_ActionMapLock.WithSharedLock([&] {
+ ZEN_DEBUG("compute scheduler TICK (Pending: {}, Running: {}, Results: {}) timeout: {}",
+ m_PendingActions.size(),
+ m_RunningMap.size(),
+ m_ResultsMap.size(),
+ TimeoutMs);
+ });
HandleActionUpdates();
- // Auto-transition Draining → Paused when all work is done
+ // Auto-transition Draining -> Paused when all work is done
if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining)
{
- size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
- size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
+ bool AllDrained = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.empty() && m_RunningMap.empty(); });
- if (Pending == 0 && Running == 0)
+ if (AllDrained)
{
SessionState Expected = SessionState::Draining;
if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel))
@@ -1776,9 +1787,9 @@ ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId)
if (Config)
{
- int Value = Config["max_retries"].AsInt32(0);
+ int Value = Config["max_retries"].AsInt32(-1);
- if (Value > 0)
+ if (Value >= 0)
{
return Value;
}
@@ -1797,7 +1808,7 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn)
// Find, validate, and remove atomically under a single lock scope to prevent
// concurrent RescheduleAction calls from double-removing the same action.
- m_ResultsLock.WithExclusiveLock([&] {
+ m_ActionMapLock.WithExclusiveLock([&] {
auto It = m_ResultsMap.find(ActionLsn);
if (It == m_ResultsMap.end())
{
@@ -1855,7 +1866,7 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn)
}
}
- // Reset action state — this calls PostUpdate() internally
+ // Reset action state - this calls PostUpdate() internally
Action->ResetActionStateToPending();
int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed);
@@ -1871,26 +1882,20 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn)
bool WasRunning = false;
// Look for the action in pending or running maps
- m_RunningLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
if (auto It = m_RunningMap.find(ActionLsn); It != m_RunningMap.end())
{
Action = It->second;
WasRunning = true;
}
+ else if (auto PIt = m_PendingActions.find(ActionLsn); PIt != m_PendingActions.end())
+ {
+ Action = PIt->second;
+ }
});
if (!Action)
{
- m_PendingLock.WithSharedLock([&] {
- if (auto It = m_PendingActions.find(ActionLsn); It != m_PendingActions.end())
- {
- Action = It->second;
- }
- });
- }
-
- if (!Action)
- {
return {.Success = false, .Error = "Action not found in pending or running maps"};
}
@@ -1912,18 +1917,15 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn)
void
ComputeServiceSession::Impl::RemoveActionFromActiveMaps(int ActionLsn)
{
- m_RunningLock.WithExclusiveLock([&] {
- m_PendingLock.WithExclusiveLock([&] {
- if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
- {
- m_PendingActions.erase(ActionLsn);
- }
- else
- {
- m_RunningMap.erase(FindIt);
- }
- });
- });
+ // Caller must hold m_ActionMapLock exclusively.
+ if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
+ {
+ m_PendingActions.erase(ActionLsn);
+ }
+ else
+ {
+ m_RunningMap.erase(FindIt);
+ }
}
void
@@ -1946,7 +1948,7 @@ ComputeServiceSession::Impl::HandleActionUpdates()
// This is safe because state transitions are monotonically increasing by enum
// rank (Pending < Submitting < Running < Completed/Failed/Cancelled), so
// SetActionState rejects any transition to a lower-ranked state. By the time
- // we read ActionState() here, it reflects the highest state reached — making
+ // we read ActionState() here, it reflects the highest state reached - making
// the first occurrence per LSN authoritative and duplicates redundant.
for (Ref<RunnerAction>& Action : UpdatedActions)
{
@@ -1956,7 +1958,7 @@ ComputeServiceSession::Impl::HandleActionUpdates()
{
switch (Action->ActionState())
{
- // Newly enqueued — add to pending map for scheduling
+ // Newly enqueued - add to pending map for scheduling
case RunnerAction::State::Pending:
// Guard against a race where the session is abandoned between
// EnqueueAction (which calls PostUpdate) and this scheduler
@@ -1973,35 +1975,44 @@ ComputeServiceSession::Impl::HandleActionUpdates()
}
else
{
- m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
+ m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
}
break;
- // Async submission in progress — remains in pending map
+ // Async submission in progress - remains in pending map
case RunnerAction::State::Submitting:
break;
- // Dispatched to a runner — move from pending to running
+ // Dispatched to a runner - move from pending to running
case RunnerAction::State::Running:
- m_RunningLock.WithExclusiveLock([&] {
- m_PendingLock.WithExclusiveLock([&] {
- m_RunningMap.insert({ActionLsn, Action});
- m_PendingActions.erase(ActionLsn);
- });
+ m_ActionMapLock.WithExclusiveLock([&] {
+ m_RunningMap.insert({ActionLsn, Action});
+ m_PendingActions.erase(ActionLsn);
});
ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn);
break;
- // Retracted — pull back for rescheduling without counting against retry limit
+ // Retracted - pull back for rescheduling without counting against retry limit
case RunnerAction::State::Retracted:
{
- RemoveActionFromActiveMaps(ActionLsn);
+ m_ActionMapLock.WithExclusiveLock([&] {
+ m_RunningMap.erase(ActionLsn);
+ m_PendingActions[ActionLsn] = Action;
+ });
Action->ResetActionStateToPending();
ZEN_INFO("action {} ({}) retracted for rescheduling", Action->ActionId, ActionLsn);
break;
}
- // Terminal states — move to results, record history, notify queue
+ // Rejected - runner was at capacity, reschedule without retry cost
+ case RunnerAction::State::Rejected:
+ {
+ Action->ResetActionStateToPending();
+ ZEN_DEBUG("action {} ({}) rescheduled after runner rejection", Action->ActionId, ActionLsn);
+ break;
+ }
+
+ // Terminal states - move to results, record history, notify queue
case RunnerAction::State::Completed:
case RunnerAction::State::Failed:
case RunnerAction::State::Abandoned:
@@ -2010,7 +2021,7 @@ ComputeServiceSession::Impl::HandleActionUpdates()
auto TerminalState = Action->ActionState();
// Automatic retry for Failed/Abandoned actions with retries remaining.
- // Skip retries when the session itself is abandoned — those actions
+ // Skip retries when the session itself is abandoned - those actions
// were intentionally abandoned and should not be rescheduled.
if ((TerminalState == RunnerAction::State::Failed || TerminalState == RunnerAction::State::Abandoned) &&
m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned)
@@ -2019,7 +2030,10 @@ ComputeServiceSession::Impl::HandleActionUpdates()
if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries)
{
- RemoveActionFromActiveMaps(ActionLsn);
+ m_ActionMapLock.WithExclusiveLock([&] {
+ m_RunningMap.erase(ActionLsn);
+ m_PendingActions[ActionLsn] = Action;
+ });
// Reset triggers PostUpdate() which re-enters the action as Pending
Action->ResetActionStateToPending();
@@ -2032,18 +2046,26 @@ ComputeServiceSession::Impl::HandleActionUpdates()
MaxRetries);
break;
}
+ else
+ {
+ ZEN_WARN("action {} ({}) {} after {} retries, not rescheduling",
+ Action->ActionId,
+ ActionLsn,
+ RunnerAction::ToString(TerminalState),
+ Action->RetryCount.load(std::memory_order_relaxed));
+ }
}
- RemoveActionFromActiveMaps(ActionLsn);
+ m_ActionMapLock.WithExclusiveLock([&] {
+ RemoveActionFromActiveMaps(ActionLsn);
- // Update queue counters BEFORE publishing the result into
- // m_ResultsMap. GetActionResult erases from m_ResultsMap
- // under m_ResultsLock, so if we updated counters after
- // releasing that lock, a caller could observe ActiveCount
- // still at 1 immediately after GetActionResult returned OK.
- NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
+ // Update queue counters BEFORE publishing the result into
+ // m_ResultsMap. GetActionResult erases from m_ResultsMap
+ // under m_ActionMapLock, so if we updated counters after
+ // releasing that lock, a caller could observe ActiveCount
+ // still at 1 immediately after GetActionResult returned OK.
+ NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
- m_ResultsLock.WithExclusiveLock([&] {
m_ResultsMap[ActionLsn] = Action;
// Append to bounded action history ring
@@ -2124,10 +2146,9 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>&
ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions");
std::vector<SubmitResult> Results(Actions.size());
- // First try submitting the batch to local runners in parallel
+ // First try submitting the batch to local runners
std::vector<SubmitResult> LocalResults = m_LocalRunnerGroup.SubmitActions(Actions);
- std::vector<size_t> RemoteIndices;
std::vector<Ref<RunnerAction>> RemoteActions;
for (size_t i = 0; i < Actions.size(); ++i)
@@ -2138,20 +2159,40 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>&
}
else
{
- RemoteIndices.push_back(i);
RemoteActions.push_back(Actions[i]);
+ Results[i] = SubmitResult{.IsAccepted = true, .Reason = "dispatched to remote"};
}
}
- // Submit remaining actions to remote runners in parallel
+ // Dispatch remaining actions to remote runners asynchronously.
+ // Mark actions as Submitting so the scheduler won't re-pick them.
+ // The remote runner will transition them to Running on success, or
+ // we mark them Failed on rejection so HandleActionUpdates retries.
if (!RemoteActions.empty())
{
- std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions);
-
- for (size_t j = 0; j < RemoteIndices.size(); ++j)
+ for (const Ref<RunnerAction>& Action : RemoteActions)
{
- Results[RemoteIndices[j]] = std::move(RemoteResults[j]);
+ Action->SetActionState(RunnerAction::State::Submitting);
}
+
+ m_RemoteSubmitPool.ScheduleWork(
+ [this, RemoteActions = std::move(RemoteActions)]() {
+ std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions);
+
+ for (size_t j = 0; j < RemoteResults.size(); ++j)
+ {
+ if (!RemoteResults[j].IsAccepted)
+ {
+ ZEN_DEBUG("remote submission rejected for action {} ({}): {}",
+ RemoteActions[j]->ActionId,
+ RemoteActions[j]->ActionLsn,
+ RemoteResults[j].Reason);
+
+ RemoteActions[j]->SetActionState(RunnerAction::State::Rejected);
+ }
+ }
+ },
+ WorkerThreadPool::EMode::EnableBacklog);
}
return Results;
@@ -2217,16 +2258,22 @@ ComputeServiceSession::NotifyOrchestratorChanged()
m_Impl->NotifyOrchestratorChanged();
}
-void
+bool
ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath)
{
- m_Impl->StartRecording(InResolver, RecordingPath);
+ return m_Impl->StartRecording(InResolver, RecordingPath);
}
-void
+bool
ComputeServiceSession::StopRecording()
{
- m_Impl->StopRecording();
+ return m_Impl->StopRecording();
+}
+
+bool
+ComputeServiceSession::IsRecording() const
+{
+ return m_Impl->IsRecording();
}
ComputeServiceSession::ActionCounts
@@ -2282,6 +2329,18 @@ ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::files
}
void
+ComputeServiceSession::AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions)
+{
+ ZEN_TRACE_CPU("ComputeServiceSession::AddManagedLocalRunner");
+
+ auto* NewRunner =
+ new ManagedProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, MaxConcurrentActions);
+
+ m_Impl->SyncWorkersToRunner(*NewRunner);
+ m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner);
+}
+
+void
ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName)
{
ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner");
diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp
index bdfd9d197..5ab189d89 100644
--- a/src/zencompute/httpcomputeservice.cpp
+++ b/src/zencompute/httpcomputeservice.cpp
@@ -21,12 +21,14 @@
# include <zencore/thread.h>
# include <zencore/trace.h>
# include <zencore/uid.h>
-# include <zenstore/cidstore.h>
+# include <zenstore/hashkeyset.h>
+# include <zenstore/zenstore.h>
# include <zentelemetry/stats.h>
# include <algorithm>
# include <span>
# include <unordered_map>
+# include <utility>
# include <vector>
using namespace std::literals;
@@ -45,7 +47,9 @@ auto OidMatcher = [](std::string_view Str) { return Str.size() == 24 && AsciiSe
struct HttpComputeService::Impl
{
HttpComputeService* m_Self;
- CidStore& m_CidStore;
+ ChunkStore& m_ActionStore;
+ ChunkStore& m_WorkerStore;
+ FallbackChunkResolver m_CombinedResolver;
IHttpStatsService& m_StatsService;
LoggerRef m_Log;
std::filesystem::path m_BaseDir;
@@ -58,6 +62,8 @@ struct HttpComputeService::Impl
RwLock m_WsConnectionsLock;
std::vector<Ref<WebSocketConnection>> m_WsConnections;
+ std::function<void()> m_ShutdownCallback;
+
// Metrics
metrics::OperationTiming m_HttpRequests;
@@ -72,13 +78,13 @@ struct HttpComputeService::Impl
std::string ClientHostname; // empty if no hostname was provided
};
- // Remote queue registry — all three maps share the same RemoteQueueInfo objects.
+ // Remote queue registry - all three maps share the same RemoteQueueInfo objects.
// All maps are guarded by m_RemoteQueueLock.
RwLock m_RemoteQueueLock;
- std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken; // Token → info
- std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId; // QueueId → info
- std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag; // idempotency key → info
+ std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken; // Token -> info
+ std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId; // QueueId -> info
+ std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag; // idempotency key -> info
LoggerRef Log() { return m_Log; }
@@ -93,34 +99,38 @@ struct HttpComputeService::Impl
uint64_t NewBytes = 0;
};
- IngestStats IngestPackageAttachments(const CbPackage& Package);
- bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList);
- void HandleWorkersGet(HttpServerRequest& HttpReq);
- void HandleWorkersAllGet(HttpServerRequest& HttpReq);
- void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status);
- void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId);
- void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker);
+ bool IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats);
+ bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList);
+ bool ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& Attachment);
+ void HandleWorkersGet(HttpServerRequest& HttpReq);
+ void HandleWorkersAllGet(HttpServerRequest& HttpReq);
+ void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status);
+ void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId);
+ void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker);
// WebSocket / observer
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection);
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri);
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code);
void OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions);
void RegisterRoutes();
- Impl(HttpComputeService* Self,
- CidStore& InCidStore,
- IHttpStatsService& StatsService,
- const std::filesystem::path& BaseDir,
- int32_t MaxConcurrentActions)
+ Impl(HttpComputeService* Self,
+ ChunkStore& InActionStore,
+ ChunkStore& InWorkerStore,
+ IHttpStatsService& StatsService,
+ std::filesystem::path BaseDir,
+ int32_t MaxConcurrentActions)
: m_Self(Self)
- , m_CidStore(InCidStore)
+ , m_ActionStore(InActionStore)
+ , m_WorkerStore(InWorkerStore)
+ , m_CombinedResolver(InActionStore, InWorkerStore)
, m_StatsService(StatsService)
, m_Log(logging::Get("compute"))
- , m_BaseDir(BaseDir)
- , m_ComputeService(InCidStore)
+ , m_BaseDir(std::move(BaseDir))
+ , m_ComputeService(m_CombinedResolver)
{
- m_ComputeService.AddLocalRunner(InCidStore, m_BaseDir / "local", MaxConcurrentActions);
+ m_ComputeService.AddLocalRunner(m_CombinedResolver, m_BaseDir / "local", MaxConcurrentActions);
m_ComputeService.WaitUntilReady();
m_StatsService.RegisterHandler("compute", *m_Self);
RegisterRoutes();
@@ -182,6 +192,65 @@ HttpComputeService::Impl::RegisterRoutes()
HttpVerb::kPost);
m_Router.RegisterRoute(
+ "session/drain",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Draining))
+ {
+ CbObjectWriter Cbo;
+ Cbo << "state"sv << ToString(m_ComputeService.GetSessionState());
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+
+ CbObjectWriter Cbo;
+ Cbo << "error"sv
+ << "Cannot transition to Draining from current state"sv;
+ HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "session/status",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObjectWriter Cbo;
+ Cbo << "state"sv << ToString(m_ComputeService.GetSessionState());
+ auto Counts = m_ComputeService.GetActionCounts();
+ Cbo << "actions_pending"sv << Counts.Pending;
+ Cbo << "actions_running"sv << Counts.Running;
+ Cbo << "actions_completed"sv << Counts.Completed;
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "session/sunset",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Sunset))
+ {
+ CbObjectWriter Cbo;
+ Cbo << "state"sv << ToString(m_ComputeService.GetSessionState());
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+
+ if (m_ShutdownCallback)
+ {
+ m_ShutdownCallback();
+ }
+ return;
+ }
+
+ CbObjectWriter Cbo;
+ Cbo << "error"sv
+ << "Cannot transition to Sunset from current state"sv;
+ HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
"workers",
[this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); },
HttpVerb::kGet);
@@ -373,7 +442,7 @@ HttpComputeService::Impl::RegisterRoutes()
if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output);
ResponseCode != HttpResponseCode::OK)
{
- ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode))
+ ZEN_DEBUG("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode))
if (ResponseCode == HttpResponseCode::NotFound)
{
@@ -498,9 +567,19 @@ HttpComputeService::Impl::RegisterRoutes()
return HttpReq.WriteResponse(HttpResponseCode::Forbidden);
}
- m_ComputeService.StartRecording(m_CidStore, m_BaseDir / "recording");
+ std::filesystem::path RecordingPath = m_BaseDir / "recording";
+
+ if (!m_ComputeService.StartRecording(m_CombinedResolver, RecordingPath))
+ {
+ CbObjectWriter Cbo;
+ Cbo << "error"
+ << "recording is already active";
+ return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ }
- return HttpReq.WriteResponse(HttpResponseCode::OK);
+ CbObjectWriter Cbo;
+ Cbo << "path" << RecordingPath.string();
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
},
HttpVerb::kPost);
@@ -514,9 +593,19 @@ HttpComputeService::Impl::RegisterRoutes()
return HttpReq.WriteResponse(HttpResponseCode::Forbidden);
}
- m_ComputeService.StopRecording();
+ std::filesystem::path RecordingPath = m_BaseDir / "recording";
- return HttpReq.WriteResponse(HttpResponseCode::OK);
+ if (!m_ComputeService.StopRecording())
+ {
+ CbObjectWriter Cbo;
+ Cbo << "error"
+ << "no recording is active";
+ return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ }
+
+ CbObjectWriter Cbo;
+ Cbo << "path" << RecordingPath.string();
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
},
HttpVerb::kPost);
@@ -583,7 +672,7 @@ HttpComputeService::Impl::RegisterRoutes()
},
HttpVerb::kGet | HttpVerb::kPost);
- // Queue creation routes — these remain separate since local creates a plain queue
+ // Queue creation routes - these remain separate since local creates a plain queue
// while remote additionally generates an OID token for external access.
m_Router.RegisterRoute(
@@ -637,7 +726,7 @@ HttpComputeService::Impl::RegisterRoutes()
return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
}
- // Queue has since expired — clean up stale entries and fall through to create a new one
+ // Queue has since expired - clean up stale entries and fall through to create a new one
m_RemoteQueuesByToken.erase(Existing->Token);
m_RemoteQueuesByQueueId.erase(Existing->QueueId);
m_RemoteQueuesByTag.erase(It);
@@ -666,7 +755,7 @@ HttpComputeService::Impl::RegisterRoutes()
},
HttpVerb::kPost);
- // Unified queue routes — {queueref} accepts both local integer IDs and remote OID tokens.
+ // Unified queue routes - {queueref} accepts both local integer IDs and remote OID tokens.
// ResolveQueueRef() handles access control (local-only for integer IDs) and token resolution.
m_Router.RegisterRoute(
@@ -1016,7 +1105,7 @@ HttpComputeService::Impl::RegisterRoutes()
},
HttpVerb::kPost);
- // WebSocket upgrade endpoint — the handler logic lives in
+ // WebSocket upgrade endpoint - the handler logic lives in
// HttpComputeService::OnWebSocket* methods; this route merely
// satisfies the router so the upgrade request isn't rejected.
m_Router.RegisterRoute(
@@ -1027,11 +1116,12 @@ HttpComputeService::Impl::RegisterRoutes()
//////////////////////////////////////////////////////////////////////////
-HttpComputeService::HttpComputeService(CidStore& InCidStore,
+HttpComputeService::HttpComputeService(ChunkStore& InActionStore,
+ ChunkStore& InWorkerStore,
IHttpStatsService& StatsService,
const std::filesystem::path& BaseDir,
int32_t MaxConcurrentActions)
-: m_Impl(std::make_unique<Impl>(this, InCidStore, StatsService, BaseDir, MaxConcurrentActions))
+: m_Impl(std::make_unique<Impl>(this, InActionStore, InWorkerStore, StatsService, BaseDir, MaxConcurrentActions))
{
}
@@ -1057,6 +1147,12 @@ HttpComputeService::GetActionCounts()
return m_Impl->m_ComputeService.GetActionCounts();
}
+void
+HttpComputeService::SetShutdownCallback(std::function<void()> Callback)
+{
+ m_Impl->m_ShutdownCallback = std::move(Callback);
+}
+
const char*
HttpComputeService::BaseUri() const
{
@@ -1145,7 +1241,7 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin
{
if (OidMatcher(Capture))
{
- // Remote OID token — accessible from any client
+ // Remote OID token - accessible from any client
const Oid Token = Oid::FromHexString(Capture);
const int QueueId = ResolveQueueToken(Token);
@@ -1157,7 +1253,7 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin
return QueueId;
}
- // Local integer queue ID — restricted to local machine requests
+ // Local integer queue ID - restricted to local machine requests
if (!HttpReq.IsLocalMachineRequest())
{
HttpReq.WriteResponse(HttpResponseCode::Forbidden);
@@ -1167,35 +1263,81 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin
return ParseInt<int>(Capture).value_or(0);
}
-HttpComputeService::Impl::IngestStats
-HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package)
+bool
+HttpComputeService::Impl::ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& Attachment)
{
- IngestStats Stats;
+ const IoHash ClaimedHash = Attachment.GetHash();
+ CompressedBuffer Buffer = Attachment.AsCompressedBinary();
+ const IoHash HeaderHash = Buffer.DecodeRawHash();
+ if (HeaderHash != ClaimedHash)
+ {
+ ZEN_WARN("attachment header hash mismatch: claimed {} but header contains {}", ClaimedHash, HeaderHash);
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ return false;
+ }
+
+ IoHashStream Hasher;
+
+ bool DecompressOk = Buffer.DecompressToStream(
+ 0,
+ Buffer.DecodeRawSize(),
+ [&](uint64_t /*SourceOffset*/, uint64_t /*SourceSize*/, uint64_t /*Offset*/, const CompositeBuffer& Range) -> bool {
+ for (const SharedBuffer& Segment : Range.GetSegments())
+ {
+ Hasher.Append(Segment.GetView());
+ }
+ return true;
+ });
+
+ if (!DecompressOk)
+ {
+ ZEN_WARN("attachment {}: failed to decompress", ClaimedHash);
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ return false;
+ }
+
+ const IoHash ActualHash = Hasher.GetHash();
+
+ if (ActualHash != ClaimedHash)
+ {
+ ZEN_WARN("attachment hash mismatch: claimed {} but decompressed data hashes to {}", ClaimedHash, ActualHash);
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ return false;
+ }
+
+ return true;
+}
+
+bool
+HttpComputeService::Impl::IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats)
+{
for (const CbAttachment& Attachment : Package.GetAttachments())
{
ZEN_ASSERT(Attachment.IsCompressedBinary());
- const IoHash DataHash = Attachment.GetHash();
- CompressedBuffer DataView = Attachment.AsCompressedBinary();
-
- ZEN_UNUSED(DataHash);
+ if (!ValidateAttachmentHash(HttpReq, Attachment))
+ {
+ return false;
+ }
- const uint64_t CompressedSize = DataView.GetCompressedSize();
+ const IoHash DataHash = Attachment.GetHash();
+ CompressedBuffer DataView = Attachment.AsCompressedBinary();
+ const uint64_t CompressedSize = DataView.GetCompressedSize();
- Stats.Bytes += CompressedSize;
- ++Stats.Count;
+ OutStats.Bytes += CompressedSize;
+ ++OutStats.Count;
- const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
+ const ChunkStore::InsertResult InsertResult = m_ActionStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
if (InsertResult.New)
{
- Stats.NewBytes += CompressedSize;
- ++Stats.NewCount;
+ OutStats.NewBytes += CompressedSize;
+ ++OutStats.NewCount;
}
}
- return Stats;
+ return true;
}
bool
@@ -1204,7 +1346,7 @@ HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vecto
ActionObj.IterateAttachments([&](CbFieldView Field) {
const IoHash FileHash = Field.AsHash();
- if (!m_CidStore.ContainsChunk(FileHash))
+ if (!m_ActionStore.ContainsChunk(FileHash))
{
NeedList.push_back(FileHash);
}
@@ -1253,7 +1395,10 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que
{
CbPackage Package = HttpReq.ReadPayloadPackage();
Body = Package.GetObject();
- Stats = IngestPackageAttachments(Package);
+ if (!IngestPackageAttachments(HttpReq, Package, Stats))
+ {
+ return; // validation failed, response already written
+ }
break;
}
@@ -1268,8 +1413,7 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que
{
// --- Batch path ---
- // For CbObject payloads, check all attachments upfront before enqueuing anything
- if (HttpReq.RequestContentType() == HttpContentType::kCbObject)
+ // Verify all action attachment references exist in the store
{
std::vector<IoHash> NeedList;
@@ -1345,7 +1489,6 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que
// --- Single-action path: Body is the action itself ---
- if (HttpReq.RequestContentType() == HttpContentType::kCbObject)
{
std::vector<IoHash> NeedList;
@@ -1453,7 +1596,7 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const
CbPackage WorkerPackage;
WorkerPackage.SetObject(WorkerSpec);
- m_CidStore.FilterChunks(ChunkSet);
+ m_WorkerStore.FilterChunks(ChunkSet);
if (ChunkSet.IsEmpty())
{
@@ -1491,15 +1634,19 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const
{
ZEN_ASSERT(Attachment.IsCompressedBinary());
+ if (!ValidateAttachmentHash(HttpReq, Attachment))
+ {
+ return;
+ }
+
const IoHash DataHash = Attachment.GetHash();
CompressedBuffer Buffer = Attachment.AsCompressedBinary();
- ZEN_UNUSED(DataHash);
TotalAttachmentBytes += Buffer.GetCompressedSize();
++AttachmentCount;
- const CidStore::InsertResult InsertResult =
- m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash);
+ const ChunkStore::InsertResult InsertResult =
+ m_WorkerStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash);
if (InsertResult.New)
{
@@ -1537,9 +1684,9 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const
//
void
-HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
- m_Impl->OnWebSocketOpen(std::move(Connection));
+ m_Impl->OnWebSocketOpen(std::move(Connection), RelativeUri);
}
void
@@ -1562,12 +1709,13 @@ HttpComputeService::OnActionsCompleted(std::span<const CompletedActionNotificati
//////////////////////////////////////////////////////////////////////////
//
-// Impl — WebSocket / observer
+// Impl - WebSocket / observer
//
void
-HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
+ ZEN_UNUSED(RelativeUri);
ZEN_INFO("compute WebSocket client connected");
m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
}
diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp
index 6cbe01e04..56eadcd57 100644
--- a/src/zencompute/httporchestrator.cpp
+++ b/src/zencompute/httporchestrator.cpp
@@ -7,6 +7,7 @@
# include <zencompute/orchestratorservice.h>
# include <zencore/compactbinarybuilder.h>
# include <zencore/logging.h>
+# include <zencore/session.h>
# include <zencore/string.h>
# include <zencore/system.h>
@@ -77,10 +78,47 @@ ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnn
return Ann.Id;
}
+static OrchestratorService::WorkerAnnotator
+MakeWorkerAnnotator(IProvisionerStateProvider* Prov)
+{
+ if (!Prov)
+ {
+ return {};
+ }
+ return [Prov](std::string_view WorkerId, CbObjectWriter& Cbo) {
+ AgentProvisioningStatus Status = Prov->GetAgentStatus(WorkerId);
+ if (Status != AgentProvisioningStatus::Unknown)
+ {
+ const char* StatusStr = (Status == AgentProvisioningStatus::Draining) ? "draining" : "active";
+ Cbo << "provisioner_status" << std::string_view(StatusStr);
+ }
+ };
+}
+
+bool
+HttpOrchestratorService::ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId)
+{
+ std::string_view SessionStr = Data["coordinator_session"].AsString("");
+ if (SessionStr.empty())
+ {
+ return true; // backwards compatibility: accept announcements without a session
+ }
+ Oid Session = Oid::TryFromHexString(SessionStr);
+ if (Session == m_SessionId)
+ {
+ return true;
+ }
+ ZEN_WARN("rejecting stale announcement from '{}' (session {} != {})", WorkerId, SessionStr, m_SessionId.ToString());
+ return false;
+}
+
HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket)
: m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket))
, m_Hostname(GetMachineName())
{
+ m_SessionId = zen::GetSessionId();
+ ZEN_INFO("orchestrator session id: {}", m_SessionId.ToString());
+
m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
@@ -95,13 +133,17 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
[this](HttpRouterRequest& Req) {
CbObjectWriter Cbo;
Cbo << "hostname" << std::string_view(m_Hostname);
+ Cbo << "session_id" << m_SessionId.ToString();
Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save());
},
HttpVerb::kGet);
m_Router.RegisterRoute(
"provision",
- [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
+ [this](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK,
+ m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))));
+ },
HttpVerb::kPost);
m_Router.RegisterRoute(
@@ -122,6 +164,11 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
"characters and uri must start with http:// or https://");
}
+ if (!ValidateCoordinatorSession(Data, WorkerId))
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::Conflict, HttpContentType::kText, "Stale coordinator session");
+ }
+
m_Service->AnnounceWorker(Ann);
HttpReq.WriteResponse(HttpResponseCode::OK);
@@ -135,7 +182,10 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
m_Router.RegisterRoute(
"agents",
- [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
+ [this](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK,
+ m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))));
+ },
HttpVerb::kGet);
m_Router.RegisterRoute(
@@ -241,6 +291,59 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
},
HttpVerb::kGet);
+ // Provisioner endpoints
+
+ m_Router.RegisterRoute(
+ "provisioner/status",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObjectWriter Cbo;
+ if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire))
+ {
+ Cbo << "name" << Prov->GetName();
+ Cbo << "target_cores" << Prov->GetTargetCoreCount();
+ Cbo << "estimated_cores" << Prov->GetEstimatedCoreCount();
+ Cbo << "active_cores" << Prov->GetActiveCoreCount();
+ Cbo << "agents" << Prov->GetAgentCount();
+ Cbo << "agents_draining" << Prov->GetDrainingAgentCount();
+ }
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "provisioner/target",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObject Data = HttpReq.ReadPayloadObject();
+ int32_t Cores = Data["target_cores"].AsInt32(-1);
+
+ ZEN_INFO("provisioner/target: received request (target_cores={}, payload_valid={})", Cores, Data ? true : false);
+
+ if (Cores < 0)
+ {
+ ZEN_WARN("provisioner/target: bad request (target_cores={})", Cores);
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Missing or invalid target_cores field");
+ }
+
+ IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire);
+ if (!Prov)
+ {
+ ZEN_WARN("provisioner/target: no provisioner configured");
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "No provisioner configured");
+ }
+
+ ZEN_INFO("provisioner/target: setting target to {} cores", Cores);
+ Prov->SetTargetCoreCount(static_cast<uint32_t>(Cores));
+
+ CbObjectWriter Cbo;
+ Cbo << "target_cores" << Prov->GetTargetCoreCount();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
// Client tracking endpoints
m_Router.RegisterRoute(
@@ -375,7 +478,7 @@ HttpOrchestratorService::Shutdown()
m_PushThread.join();
}
- // Clean up worker WebSocket connections — collect IDs under lock, then
+ // Clean up worker WebSocket connections - collect IDs under lock, then
// notify the service outside the lock to avoid lock-order inversions.
std::vector<std::string> WorkerIds;
m_WorkerWsLock.WithExclusiveLock([&] {
@@ -411,6 +514,13 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request)
}
}
+void
+HttpOrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider)
+{
+ m_Provisioner.store(Provider, std::memory_order_release);
+ m_Service->SetProvisionerStateProvider(Provider);
+}
+
//////////////////////////////////////////////////////////////////////////
//
// IWebSocketHandler
@@ -418,8 +528,9 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request)
# if ZEN_WITH_WEBSOCKETS
void
-HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
+ ZEN_UNUSED(RelativeUri);
if (!m_PushEnabled.load())
{
return;
@@ -471,7 +582,7 @@ std::string
HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Msg)
{
// Workers send CbObject in native binary format over the WebSocket to
- // avoid the lossy CbObject↔JSON round-trip.
+ // avoid the lossy CbObject<->JSON round-trip.
CbObject Data = CbObject::MakeView(Msg.Payload.GetData());
if (!Data)
{
@@ -487,6 +598,11 @@ HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Ms
return {};
}
+ if (!ValidateCoordinatorSession(Data, WorkerId))
+ {
+ return {};
+ }
+
m_Service->AnnounceWorker(Ann);
return std::string(WorkerId);
}
@@ -562,7 +678,7 @@ HttpOrchestratorService::PushThreadFunction()
}
// Build combined JSON with worker list, provisioning history, clients, and client history
- CbObject WorkerList = m_Service->GetWorkerList();
+ CbObject WorkerList = m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)));
CbObject History = m_Service->GetProvisioningHistory(50);
CbObject ClientList = m_Service->GetClientList();
CbObject ClientHistory = m_Service->GetClientHistory(50);
@@ -614,6 +730,20 @@ HttpOrchestratorService::PushThreadFunction()
JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2));
}
+ // Emit provisioner stats if available
+ if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire))
+ {
+ JsonBuilder.Append(
+ fmt::format(",\"provisioner\":{{\"name\":\"{}\",\"target_cores\":{},\"estimated_cores\":{}"
+ ",\"active_cores\":{},\"agents\":{},\"agents_draining\":{}}}",
+ Prov->GetName(),
+ Prov->GetTargetCoreCount(),
+ Prov->GetEstimatedCoreCount(),
+ Prov->GetActiveCoreCount(),
+ Prov->GetAgentCount(),
+ Prov->GetDrainingAgentCount()));
+ }
+
JsonBuilder.Append("}");
std::string_view Json = JsonBuilder.ToView();
diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h
index 3b9642ac3..280d794e7 100644
--- a/src/zencompute/include/zencompute/cloudmetadata.h
+++ b/src/zencompute/include/zencompute/cloudmetadata.h
@@ -64,7 +64,7 @@ public:
explicit CloudMetadata(std::filesystem::path DataDir);
/** Synchronously probes cloud providers at the given IMDS endpoint.
- * Intended for testing — allows redirecting all IMDS queries to a local
+ * Intended for testing - allows redirecting all IMDS queries to a local
* mock HTTP server instead of the real 169.254.169.254 endpoint.
*/
CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint);
diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h
index 1ca78738a..97de4321a 100644
--- a/src/zencompute/include/zencompute/computeservice.h
+++ b/src/zencompute/include/zencompute/computeservice.h
@@ -167,6 +167,7 @@ public:
// Action runners
void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0);
+ void AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0);
void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName);
// Action submission
@@ -278,7 +279,7 @@ public:
// sized to match RunnerAction::State::_Count but we can't use the enum here
// for dependency reasons, so just use a fixed size array and static assert in
// the implementation file
- uint64_t Timestamps[9] = {};
+ uint64_t Timestamps[10] = {};
};
[[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100);
@@ -304,8 +305,9 @@ public:
// Recording
- void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath);
- void StopRecording();
+ bool StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath);
+ bool StopRecording();
+ bool IsRecording() const;
private:
void PostUpdate(RunnerAction* Action);
diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h
index b58e73a0d..32f54f293 100644
--- a/src/zencompute/include/zencompute/httpcomputeservice.h
+++ b/src/zencompute/include/zencompute/httpcomputeservice.h
@@ -15,7 +15,7 @@
# include <memory>
namespace zen {
-class CidStore;
+class ChunkStore;
}
namespace zen::compute {
@@ -26,7 +26,8 @@ namespace zen::compute {
class HttpComputeService : public HttpService, public IHttpStatsProvider, public IWebSocketHandler, public IComputeCompletionObserver
{
public:
- HttpComputeService(CidStore& InCidStore,
+ HttpComputeService(ChunkStore& InActionStore,
+ ChunkStore& InWorkerStore,
IHttpStatsService& StatsService,
const std::filesystem::path& BaseDir,
int32_t MaxConcurrentActions = 0);
@@ -34,6 +35,10 @@ public:
void Shutdown();
+ /** Set a callback to be invoked when the session/sunset endpoint is hit.
+ * Typically wired to HttpServer::RequestExit() to shut down the process. */
+ void SetShutdownCallback(std::function<void()> Callback);
+
[[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts();
const char* BaseUri() const override;
@@ -45,7 +50,7 @@ public:
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h
index da5c5dfc3..4e4f5f0f8 100644
--- a/src/zencompute/include/zencompute/httporchestrator.h
+++ b/src/zencompute/include/zencompute/httporchestrator.h
@@ -2,10 +2,12 @@
#pragma once
+#include <zencompute/provisionerstate.h>
#include <zencompute/zencompute.h>
#include <zencore/logging.h>
#include <zencore/thread.h>
+#include <zencore/uid.h>
#include <zenhttp/httpserver.h>
#include <zenhttp/websocket.h>
@@ -65,12 +67,22 @@ public:
*/
void Shutdown();
+ /** Return the session ID generated at construction time. Provisioners
+ * pass this to spawned workers so the orchestrator can reject stale
+ * announcements from previous sessions. */
+ Oid GetSessionId() const { return m_SessionId; }
+
+ /** Register a provisioner whose target core count can be read and changed
+ * via the orchestrator HTTP API and dashboard. Caller retains ownership;
+ * the provider must outlive this service. */
+ void SetProvisionerStateProvider(IProvisionerStateProvider* Provider);
+
virtual const char* BaseUri() const override;
virtual void HandleRequest(HttpServerRequest& Request) override;
// IWebSocketHandler
#if ZEN_WITH_WEBSOCKETS
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
#endif
@@ -81,6 +93,11 @@ private:
std::unique_ptr<OrchestratorService> m_Service;
std::string m_Hostname;
+ Oid m_SessionId;
+ bool ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId);
+
+ std::atomic<IProvisionerStateProvider*> m_Provisioner{nullptr};
+
// WebSocket push
#if ZEN_WITH_WEBSOCKETS
@@ -91,9 +108,9 @@ private:
Event m_PushEvent;
void PushThreadFunction();
- // Worker WebSocket connections (worker→orchestrator persistent links)
+ // Worker WebSocket connections (worker->orchestrator persistent links)
RwLock m_WorkerWsLock;
- std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr → worker ID
+ std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr -> worker ID
std::string HandleWorkerWebSocketMessage(const WebSocketMessage& Msg);
#endif
};
diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h
index 704306913..6074240b9 100644
--- a/src/zencompute/include/zencompute/mockimds.h
+++ b/src/zencompute/include/zencompute/mockimds.h
@@ -1,5 +1,5 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-// Moved to zenutil — this header is kept for backward compatibility.
+// Moved to zenutil - this header is kept for backward compatibility.
#pragma once
diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h
index 071e902b3..2c49e22df 100644
--- a/src/zencompute/include/zencompute/orchestratorservice.h
+++ b/src/zencompute/include/zencompute/orchestratorservice.h
@@ -6,7 +6,10 @@
#if ZEN_WITH_COMPUTE_SERVICES
+# include <zencompute/provisionerstate.h>
# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/logbase.h>
# include <zencore/thread.h>
# include <zencore/timer.h>
# include <zencore/uid.h>
@@ -88,9 +91,16 @@ public:
std::string Hostname;
};
- CbObject GetWorkerList();
+ /** Per-worker callback invoked during GetWorkerList serialization.
+ * The callback receives the worker ID and a CbObjectWriter positioned
+ * inside the worker's object, allowing the caller to append extra fields. */
+ using WorkerAnnotator = std::function<void(std::string_view WorkerId, CbObjectWriter& Cbo)>;
+
+ CbObject GetWorkerList(const WorkerAnnotator& Annotate = {});
void AnnounceWorker(const WorkerAnnouncement& Announcement);
+ void SetProvisionerStateProvider(IProvisionerStateProvider* Provider);
+
bool IsWorkerWebSocketEnabled() const;
void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected);
@@ -164,7 +174,12 @@ private:
void RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname);
- bool m_EnableWorkerWebSocket = false;
+ LoggerRef Log() { return m_Log; }
+
+ LoggerRef m_Log{"compute.orchestrator"};
+ bool m_EnableWorkerWebSocket = false;
+
+ std::atomic<IProvisionerStateProvider*> m_Provisioner{nullptr};
std::thread m_ProbeThread;
std::atomic<bool> m_ProbeThreadEnabled{true};
diff --git a/src/zencompute/include/zencompute/provisionerstate.h b/src/zencompute/include/zencompute/provisionerstate.h
new file mode 100644
index 000000000..e9af8a635
--- /dev/null
+++ b/src/zencompute/include/zencompute/provisionerstate.h
@@ -0,0 +1,38 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <cstdint>
+#include <string_view>
+
+namespace zen::compute {
+
+/** Per-agent provisioning status as seen by the provisioner. */
+enum class AgentProvisioningStatus
+{
+ Unknown, ///< Not known to the provisioner
+ Active, ///< Running and allocated
+ Draining, ///< Being gracefully deprovisioned
+};
+
+/** Abstract interface for querying and controlling a provisioner from the HTTP layer.
+ * This decouples the orchestrator service from specific provisioner implementations. */
+class IProvisionerStateProvider
+{
+public:
+ virtual ~IProvisionerStateProvider() = default;
+
+ virtual std::string_view GetName() const = 0; ///< e.g. "horde", "nomad"
+ virtual uint32_t GetTargetCoreCount() const = 0;
+ virtual uint32_t GetEstimatedCoreCount() const = 0;
+ virtual uint32_t GetActiveCoreCount() const = 0;
+ virtual uint32_t GetAgentCount() const = 0;
+ virtual uint32_t GetDrainingAgentCount() const { return 0; }
+ virtual void SetTargetCoreCount(uint32_t Count) = 0;
+
+ /** Return the provisioning status for a worker by its orchestrator ID
+ * (e.g. "horde-{LeaseId}"). Returns Unknown if the ID is not recognized. */
+ virtual AgentProvisioningStatus GetAgentStatus(std::string_view /*WorkerId*/) const { return AgentProvisioningStatus::Unknown; }
+};
+
+} // namespace zen::compute
diff --git a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp
index 9ea695305..68199ab3c 100644
--- a/src/zencompute/orchestratorservice.cpp
+++ b/src/zencompute/orchestratorservice.cpp
@@ -31,7 +31,7 @@ OrchestratorService::~OrchestratorService()
}
CbObject
-OrchestratorService::GetWorkerList()
+OrchestratorService::GetWorkerList(const WorkerAnnotator& Annotate)
{
ZEN_TRACE_CPU("OrchestratorService::GetWorkerList");
CbObjectWriter Cbo;
@@ -71,6 +71,10 @@ OrchestratorService::GetWorkerList()
Cbo << "ws_connected" << true;
}
Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs();
+ if (Annotate)
+ {
+ Annotate(WorkerId, Cbo);
+ }
Cbo.EndObject();
}
});
@@ -144,6 +148,12 @@ OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann)
}
}
+void
+OrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider)
+{
+ m_Provisioner.store(Provider, std::memory_order_release);
+}
+
bool
OrchestratorService::IsWorkerWebSocketEnabled() const
{
@@ -170,11 +180,11 @@ OrchestratorService::SetWorkerWebSocketConnected(std::string_view WorkerId, bool
if (Connected)
{
- ZEN_INFO("worker {} WebSocket connected — marking reachable", WorkerId);
+ ZEN_INFO("worker {} WebSocket connected - marking reachable", WorkerId);
}
else
{
- ZEN_WARN("worker {} WebSocket disconnected — marking unreachable", WorkerId);
+ ZEN_WARN("worker {} WebSocket disconnected - marking unreachable", WorkerId);
}
});
@@ -607,6 +617,14 @@ OrchestratorService::ProbeThreadFunction()
continue;
}
+ // Check if the provisioner knows this worker is draining - if so,
+ // unreachability is expected and should not be logged as a warning.
+ bool IsDraining = false;
+ if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire))
+ {
+ IsDraining = Prov->GetAgentStatus(Snap.Id) == AgentProvisioningStatus::Draining;
+ }
+
ReachableState NewState = ReachableState::Unreachable;
try
@@ -621,7 +639,10 @@ OrchestratorService::ProbeThreadFunction()
}
catch (const std::exception& Ex)
{
- ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what());
+ if (!IsDraining)
+ {
+ ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what());
+ }
}
ReachableState PrevState = ReachableState::Unknown;
@@ -646,6 +667,10 @@ OrchestratorService::ProbeThreadFunction()
{
ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri);
}
+ else if (IsDraining)
+ {
+ ZEN_INFO("worker {} ({}) shut down (draining)", Snap.Id, Snap.Uri);
+ }
else if (PrevState == ReachableState::Reachable)
{
ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri);
diff --git a/src/zencompute/pathvalidation.h b/src/zencompute/pathvalidation.h
new file mode 100644
index 000000000..d50ad4a2a
--- /dev/null
+++ b/src/zencompute/pathvalidation.h
@@ -0,0 +1,118 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/except_fmt.h>
+#include <zencore/string.h>
+
+#include <filesystem>
+#include <string_view>
+
+namespace zen::compute {
+
+// Validate that a single path component contains only characters that are valid
+// file/directory names on all supported platforms. Uses Windows rules as the most
+// restrictive superset, since packages may be built on one platform and consumed
+// on another.
+inline void
+ValidatePathComponent(std::string_view Component, std::string_view FullPath)
+{
+ // Reject control characters (0x00-0x1F) and characters forbidden on Windows
+ for (char Ch : Component)
+ {
+ if (static_cast<unsigned char>(Ch) < 0x20 || Ch == '<' || Ch == '>' || Ch == ':' || Ch == '"' || Ch == '|' || Ch == '?' ||
+ Ch == '*')
+ {
+ throw zen::invalid_argument("invalid character in path component '{}' of '{}'", Component, FullPath);
+ }
+ }
+
+ // Reject empty components and trailing dots or spaces (silently stripped on Windows, leading to confusion)
+ if (Component.empty() || Component.back() == '.' || Component.back() == ' ')
+ {
+ throw zen::invalid_argument("path component '{}' of '{}' has trailing dot or space", Component, FullPath);
+ }
+
+ // Reject Windows reserved device names (CON, PRN, AUX, NUL, COM1-9, LPT1-9)
+ // These are reserved with or without an extension (e.g. "CON.txt" is still reserved).
+ std::string_view Stem = Component.substr(0, Component.find('.'));
+
+ static constexpr std::string_view ReservedNames[] = {
+ "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7",
+ "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
+ };
+
+ for (std::string_view Reserved : ReservedNames)
+ {
+ if (zen::StrCaseCompare(Stem, Reserved) == 0)
+ {
+ throw zen::invalid_argument("path component '{}' of '{}' uses reserved device name '{}'", Component, FullPath, Reserved);
+ }
+ }
+}
+
+// Validate that a path extracted from a package is a safe relative path.
+// Rejects absolute paths, ".." components, and invalid platform filenames.
+inline void
+ValidateSandboxRelativePath(std::string_view Name)
+{
+ if (Name.empty())
+ {
+ throw zen::invalid_argument("path traversal detected: empty path name");
+ }
+
+ std::filesystem::path Parsed(Name);
+
+ if (Parsed.is_absolute())
+ {
+ throw zen::invalid_argument("path traversal detected: '{}' is an absolute path", Name);
+ }
+
+ for (const auto& Component : Parsed)
+ {
+ std::string ComponentStr = Component.string();
+
+ if (ComponentStr == "..")
+ {
+ throw zen::invalid_argument("path traversal detected: '{}' contains '..' component", Name);
+ }
+
+ // Skip "." (current directory) - harmless in relative paths
+ if (ComponentStr != ".")
+ {
+ ValidatePathComponent(ComponentStr, Name);
+ }
+ }
+}
+
+// Validate all path entries in a worker description CbObject.
+// Checks path, executables[].name, dirs[], and files[].name fields.
+// Throws an exception if any invalid paths are found.
+inline void
+ValidateWorkerDescriptionPaths(const CbObject& WorkerDescription)
+{
+ using namespace std::literals;
+
+ if (auto PathField = WorkerDescription["path"sv]; PathField.HasValue())
+ {
+ ValidateSandboxRelativePath(PathField.AsString());
+ }
+
+ for (auto& It : WorkerDescription["executables"sv])
+ {
+ ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString());
+ }
+
+ for (auto& It : WorkerDescription["dirs"sv])
+ {
+ ValidateSandboxRelativePath(It.AsString());
+ }
+
+ for (auto& It : WorkerDescription["files"sv])
+ {
+ ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString());
+ }
+}
+
+} // namespace zen::compute
diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp
index 4f116e7d8..34bf065b4 100644
--- a/src/zencompute/runners/functionrunner.cpp
+++ b/src/zencompute/runners/functionrunner.cpp
@@ -6,9 +6,15 @@
# include <zencore/compactbinary.h>
# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/logging.h>
+# include <zencore/string.h>
+# include <zencore/timer.h>
# include <zencore/trace.h>
+# include <zencore/workthreadpool.h>
# include <fmt/format.h>
+# include <future>
# include <vector>
namespace zen::compute {
@@ -118,23 +124,34 @@ std::vector<SubmitResult>
BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
{
ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions");
- RwLock::SharedLockScope _(m_RunnersLock);
- const int RunnerCount = gsl::narrow<int>(m_Runners.size());
+ // Snapshot runners and query capacity under the lock, then release
+ // before submitting - HTTP submissions to remote runners can take
+ // hundreds of milliseconds and we must not hold m_RunnersLock during I/O.
- if (RunnerCount == 0)
- {
- return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"});
- }
+ std::vector<Ref<FunctionRunner>> Runners;
+ std::vector<size_t> Capacities;
+ std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions;
+ size_t TotalCapacity = 0;
- // Query capacity per runner and compute total
- std::vector<size_t> Capacities(RunnerCount);
- size_t TotalCapacity = 0;
+ m_RunnersLock.WithSharedLock([&] {
+ const int RunnerCount = gsl::narrow<int>(m_Runners.size());
+ Runners.assign(m_Runners.begin(), m_Runners.end());
+ Capacities.resize(RunnerCount);
+ PerRunnerActions.resize(RunnerCount);
- for (int i = 0; i < RunnerCount; ++i)
+ for (int i = 0; i < RunnerCount; ++i)
+ {
+ Capacities[i] = Runners[i]->QueryCapacity();
+ TotalCapacity += Capacities[i];
+ }
+ });
+
+ const int RunnerCount = gsl::narrow<int>(Runners.size());
+
+ if (RunnerCount == 0)
{
- Capacities[i] = m_Runners[i]->QueryCapacity();
- TotalCapacity += Capacities[i];
+ return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"});
}
if (TotalCapacity == 0)
@@ -143,9 +160,8 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
}
// Distribute actions across runners proportionally to their available capacity
- std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions(RunnerCount);
- std::vector<size_t> ActionRunnerIndex(Actions.size());
- size_t ActionIdx = 0;
+ std::vector<size_t> ActionRunnerIndex(Actions.size());
+ size_t ActionIdx = 0;
for (int i = 0; i < RunnerCount; ++i)
{
@@ -164,8 +180,9 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
}
}
- // Assign any remaining actions to runners with capacity (round-robin)
- for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount)
+ // Assign any remaining actions to runners with capacity (round-robin).
+ // Cap at TotalCapacity to avoid spinning when there are more actions than runners can accept.
+ for (int i = 0; ActionIdx < Actions.size() && ActionIdx < TotalCapacity; i = (i + 1) % RunnerCount)
{
if (Capacities[i] > PerRunnerActions[i].size())
{
@@ -175,22 +192,83 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
}
}
- // Submit batches per runner
+ // Submit batches per runner - in parallel when a worker pool is available
+
std::vector<std::vector<SubmitResult>> PerRunnerResults(RunnerCount);
+ int ActiveRunnerCount = 0;
for (int i = 0; i < RunnerCount; ++i)
{
if (!PerRunnerActions[i].empty())
{
- PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]);
+ ++ActiveRunnerCount;
+ }
+ }
+
+ static constexpr uint64_t SubmitWarnThresholdMs = 500;
+
+ auto SubmitToRunner = [&](int RunnerIndex) {
+ auto& Runner = Runners[RunnerIndex];
+ Runner->m_LastSubmitStats.Reset();
+
+ Stopwatch Timer;
+
+ PerRunnerResults[RunnerIndex] = Runner->SubmitActions(PerRunnerActions[RunnerIndex]);
+
+ uint64_t ElapsedMs = Timer.GetElapsedTimeMs();
+ if (ElapsedMs >= SubmitWarnThresholdMs)
+ {
+ size_t Attachments = Runner->m_LastSubmitStats.TotalAttachments.load(std::memory_order_relaxed);
+ uint64_t AttachmentBytes = Runner->m_LastSubmitStats.TotalAttachmentBytes.load(std::memory_order_relaxed);
+
+ ZEN_WARN("submit of {} actions ({} attachments, {}) to '{}' took {}ms",
+ PerRunnerActions[RunnerIndex].size(),
+ Attachments,
+ NiceBytes(AttachmentBytes),
+ Runner->GetDisplayName(),
+ ElapsedMs);
+ }
+ };
+
+ if (m_WorkerPool && ActiveRunnerCount > 1)
+ {
+ std::vector<std::future<void>> Futures(RunnerCount);
+
+ for (int i = 0; i < RunnerCount; ++i)
+ {
+ if (!PerRunnerActions[i].empty())
+ {
+ std::packaged_task<void()> Task([&SubmitToRunner, i]() { SubmitToRunner(i); });
+
+ Futures[i] = m_WorkerPool->EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog);
+ }
+ }
+
+ for (int i = 0; i < RunnerCount; ++i)
+ {
+ if (Futures[i].valid())
+ {
+ Futures[i].get();
+ }
+ }
+ }
+ else
+ {
+ for (int i = 0; i < RunnerCount; ++i)
+ {
+ if (!PerRunnerActions[i].empty())
+ {
+ SubmitToRunner(i);
+ }
}
}
- // Reassemble results in original action order
- std::vector<SubmitResult> Results(Actions.size());
+ // Reassemble results in original action order.
+ // Actions beyond ActionIdx were not assigned to any runner (insufficient capacity).
+ std::vector<SubmitResult> Results(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"});
std::vector<size_t> PerRunnerIdx(RunnerCount, 0);
- for (size_t i = 0; i < Actions.size(); ++i)
+ for (size_t i = 0; i < ActionIdx; ++i)
{
size_t RunnerIdx = ActionRunnerIndex[i];
size_t Idx = PerRunnerIdx[RunnerIdx]++;
@@ -307,10 +385,11 @@ RunnerAction::RetractAction()
bool
RunnerAction::ResetActionStateToPending()
{
- // Only allow reset from Failed, Abandoned, or Retracted states
+ // Only allow reset from Failed, Abandoned, Rejected, or Retracted states
State CurrentState = m_ActionState.load();
- if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Retracted)
+ if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Rejected &&
+ CurrentState != State::Retracted)
{
return false;
}
@@ -331,11 +410,12 @@ RunnerAction::ResetActionStateToPending()
// Clear execution fields
ExecutionLocation.clear();
+ FailureReason.clear();
CpuUsagePercent.store(-1.0f, std::memory_order_relaxed);
CpuSeconds.store(0.0f, std::memory_order_relaxed);
- // Increment retry count (skip for Retracted — nothing failed)
- if (CurrentState != State::Retracted)
+ // Increment retry count (skip for Retracted/Rejected - nothing failed)
+ if (CurrentState != State::Retracted && CurrentState != State::Rejected)
{
RetryCount.fetch_add(1, std::memory_order_relaxed);
}
diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h
index 56c3f3af0..371a60b7a 100644
--- a/src/zencompute/runners/functionrunner.h
+++ b/src/zencompute/runners/functionrunner.h
@@ -10,6 +10,10 @@
# include <filesystem>
# include <vector>
+namespace zen {
+class WorkerThreadPool;
+}
+
namespace zen::compute {
struct SubmitResult
@@ -37,6 +41,22 @@ public:
[[nodiscard]] virtual bool IsHealthy() = 0;
[[nodiscard]] virtual size_t QueryCapacity();
[[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);
+ [[nodiscard]] virtual std::string_view GetDisplayName() const { return "local"; }
+
+ // Accumulated stats from the most recent SubmitActions call.
+ // Reset before each call, populated by the runner implementation.
+ struct SubmitStats
+ {
+ std::atomic<size_t> TotalAttachments{0};
+ std::atomic<uint64_t> TotalAttachmentBytes{0};
+
+ void Reset()
+ {
+ TotalAttachments.store(0, std::memory_order_relaxed);
+ TotalAttachmentBytes.store(0, std::memory_order_relaxed);
+ }
+ };
+ SubmitStats m_LastSubmitStats;
// Best-effort cancellation of a specific in-flight action. Returns true if the
// cancellation signal was successfully sent. The action will transition to Cancelled
@@ -68,6 +88,8 @@ public:
bool CancelAction(int ActionLsn);
void CancelRemoteQueue(int QueueId);
+ void SetWorkerPool(WorkerThreadPool* Pool) { m_WorkerPool = Pool; }
+
size_t GetRunnerCount()
{
return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); });
@@ -79,6 +101,7 @@ protected:
RwLock m_RunnersLock;
std::vector<Ref<FunctionRunner>> m_Runners;
std::atomic<int> m_NextSubmitIndex{0};
+ WorkerThreadPool* m_WorkerPool = nullptr;
};
/** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal.
@@ -151,6 +174,7 @@ struct RunnerAction : public RefCounted
CbObject ActionObj;
int Priority = 0;
std::string ExecutionLocation; // "local" or remote hostname
+ std::string FailureReason; // human-readable reason when action fails (empty on success)
// CPU usage and total CPU time of the running process, sampled periodically by the local runner.
// CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage.
@@ -168,6 +192,7 @@ struct RunnerAction : public RefCounted
Completed, // Finished successfully with results available
Failed, // Execution failed (transient error, eligible for retry)
Abandoned, // Infrastructure termination (e.g. spot eviction, session abandon)
+ Rejected, // Runner declined (e.g. at capacity) - rescheduled without retry cost
Cancelled, // Intentional user cancellation (never retried)
Retracted, // Pulled back for rescheduling on a different runner (no retry cost)
_Count
@@ -194,6 +219,8 @@ struct RunnerAction : public RefCounted
return "Failed";
case State::Abandoned:
return "Abandoned";
+ case State::Rejected:
+ return "Rejected";
case State::Cancelled:
return "Cancelled";
case State::Retracted:
diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp
index e79a6c90f..be4274823 100644
--- a/src/zencompute/runners/linuxrunner.cpp
+++ b/src/zencompute/runners/linuxrunner.cpp
@@ -195,7 +195,7 @@ namespace {
WriteErrorAndExit(ErrorPipeFd, "bind mount /lib failed", errno);
}
- // /lib64 (optional — not all distros have it)
+ // /lib64 (optional - not all distros have it)
{
struct stat St;
if (stat("/lib64", &St) == 0 && S_ISDIR(St.st_mode))
@@ -208,7 +208,7 @@ namespace {
}
}
- // /etc (required — for resolv.conf, ld.so.cache, etc.)
+ // /etc (required - for resolv.conf, ld.so.cache, etc.)
if (MkdirIfNeeded(BuildPath("etc"), 0755) != 0)
{
WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/etc failed", errno);
@@ -218,7 +218,7 @@ namespace {
WriteErrorAndExit(ErrorPipeFd, "bind mount /etc failed", errno);
}
- // /worker — bind-mount worker directory (contains the executable)
+ // /worker - bind-mount worker directory (contains the executable)
if (MkdirIfNeeded(BuildPath("worker"), 0755) != 0)
{
WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/worker failed", errno);
@@ -331,6 +331,8 @@ LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver,
{
ZEN_INFO("namespace sandboxing enabled for child processes");
}
+
+ StartMonitorThread();
}
SubmitResult
@@ -428,11 +430,12 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (ChildPid == 0)
{
- // Child process
+ // Child process - lower priority so workers don't starve the main server
+ nice(5);
if (m_Sandboxed)
{
- // Close read end of error pipe — child only writes
+ // Close read end of error pipe - child only writes
close(ErrorPipe[0]);
SetupNamespaceSandbox(SandboxPathStr.c_str(), CurrentUid, CurrentGid, WorkerPathStr.c_str(), ErrorPipe[1]);
@@ -459,7 +462,7 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (m_Sandboxed)
{
- // Close write end of error pipe — parent only reads
+ // Close write end of error pipe - parent only reads
close(ErrorPipe[1]);
// Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC
@@ -479,7 +482,8 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action)
// Clean up the sandbox in the background
m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath));
- ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf);
+ Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf);
+ ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason);
Action->SetActionState(RunnerAction::State::Failed);
return SubmitResult{.IsAccepted = false};
@@ -675,7 +679,7 @@ ReadProcStatCpuTicks(pid_t Pid)
Buf[Len] = '\0';
- // Skip past "pid (name) " — find last ')' to handle names containing spaces or parens
+ // Skip past "pid (name) " - find last ')' to handle names containing spaces or parens
const char* P = strrchr(Buf, ')');
if (!P)
{
@@ -705,7 +709,7 @@ LinuxProcessRunner::SampleProcessCpu(RunningAction& Running)
if (CurrentOsTicks == 0)
{
- // Process gone or /proc entry unreadable — record timestamp without updating usage
+ // Process gone or /proc entry unreadable - record timestamp without updating usage
Running.LastCpuSampleTicks = NowTicks;
Running.LastCpuOsTicks = 0;
return;
diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp
index b61e0a46f..259965e23 100644
--- a/src/zencompute/runners/localrunner.cpp
+++ b/src/zencompute/runners/localrunner.cpp
@@ -4,6 +4,8 @@
#if ZEN_WITH_COMPUTE_SERVICES
+# include "pathvalidation.h"
+
# include <zencore/compactbinary.h>
# include <zencore/compactbinarybuilder.h>
# include <zencore/compactbinarypackage.h>
@@ -104,8 +106,6 @@ LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver,
ZEN_INFO("Cleanup complete");
}
- m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this};
-
# if ZEN_PLATFORM_WINDOWS
// Suppress any error dialogs caused by missing dependencies
UINT OldMode = ::SetErrorMode(0);
@@ -337,7 +337,7 @@ LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action)
SubmitResult
LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action)
{
- // Base class is not directly usable — platform subclasses override this
+ // Base class is not directly usable - platform subclasses override this
ZEN_UNUSED(Action);
return SubmitResult{.IsAccepted = false};
}
@@ -357,14 +357,21 @@ LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker)
std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId);
- if (!std::filesystem::exists(WorkerDir))
+ // worker.zcb is written as the last step of ManifestWorker, so its presence
+ // indicates a complete manifest. If the directory exists but the marker is
+ // missing, a previous manifest was interrupted and we need to start over.
+ bool NeedsManifest = !std::filesystem::exists(WorkerDir / "worker.zcb");
+
+ if (NeedsManifest)
{
_.ReleaseNow();
RwLock::ExclusiveLockScope $(m_WorkerLock);
- if (!std::filesystem::exists(WorkerDir))
+ if (!std::filesystem::exists(WorkerDir / "worker.zcb"))
{
+ std::error_code Ec;
+ std::filesystem::remove_all(WorkerDir, Ec);
ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {});
}
}
@@ -382,6 +389,8 @@ LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromP
const IoHash ChunkHash = FileEntry["hash"sv].AsHash();
const uint64_t Size = FileEntry["size"sv].AsUInt64();
+ ValidateSandboxRelativePath(Name);
+
CompressedBuffer Compressed;
if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash))
@@ -457,7 +466,8 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage,
for (auto& It : WorkerDescription["dirs"sv])
{
- std::string_view Name = It.AsString();
+ std::string_view Name = It.AsString();
+ ValidateSandboxRelativePath(Name);
std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()};
// Validate dir path stays within sandbox
@@ -482,6 +492,8 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage,
}
WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer());
+
+ ZEN_INFO("manifested worker '{}' in '{}'", WorkerPackage.GetObjectHash(), SandboxPath);
}
CbPackage
@@ -540,6 +552,12 @@ LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath)
}
void
+LocalProcessRunner::StartMonitorThread()
+{
+ m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this};
+}
+
+void
LocalProcessRunner::MonitorThreadFunction()
{
SetCurrentThreadName("LocalProcessRunner_Monitor");
@@ -602,7 +620,7 @@ LocalProcessRunner::MonitorThreadFunction()
void
LocalProcessRunner::CancelRunningActions()
{
- // Base class is not directly usable — platform subclasses override this
+ // Base class is not directly usable - platform subclasses override this
}
void
@@ -662,9 +680,15 @@ LocalProcessRunner::ProcessCompletedActions(std::vector<Ref<RunningAction>>& Com
}
catch (std::exception& Ex)
{
- ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what());
+ Running->Action->FailureReason = fmt::format("exception gathering outputs: {}", Ex.what());
+ ZEN_ERROR("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason);
}
}
+ else
+ {
+ Running->Action->FailureReason = fmt::format("process exited with code {}", Running->ExitCode);
+ ZEN_WARN("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason);
+ }
// Failed - clean up the sandbox in the background.
diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h
index b8cff6826..d6589db43 100644
--- a/src/zencompute/runners/localrunner.h
+++ b/src/zencompute/runners/localrunner.h
@@ -67,6 +67,7 @@ protected:
{
Ref<RunnerAction> Action;
void* ProcessHandle = nullptr;
+ int Pid = 0;
int ExitCode = 0;
std::filesystem::path SandboxPath;
@@ -83,8 +84,6 @@ protected:
std::filesystem::path m_SandboxPath;
int32_t m_MaxRunningActions = 64; // arbitrary limit for testing
- // if used in conjuction with m_ResultsLock, this lock must be taken *after*
- // m_ResultsLock to avoid deadlocks
RwLock m_RunningLock;
std::unordered_map<int, Ref<RunningAction>> m_RunningMap;
@@ -95,6 +94,7 @@ protected:
std::thread m_MonitorThread;
std::atomic<bool> m_MonitorThreadEnabled{true};
Event m_MonitorThreadEvent;
+ void StartMonitorThread();
void MonitorThreadFunction();
virtual void SweepRunningActions();
virtual void CancelRunningActions();
diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp
index 5cec90699..ab24d4672 100644
--- a/src/zencompute/runners/macrunner.cpp
+++ b/src/zencompute/runners/macrunner.cpp
@@ -130,6 +130,8 @@ MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver,
{
ZEN_INFO("Seatbelt sandboxing enabled for child processes");
}
+
+ StartMonitorThread();
}
SubmitResult
@@ -209,18 +211,19 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (ChildPid == 0)
{
- // Child process
+ // Child process - lower priority so workers don't starve the main server
+ nice(5);
if (m_Sandboxed)
{
- // Close read end of error pipe — child only writes
+ // Close read end of error pipe - child only writes
close(ErrorPipe[0]);
// Apply Seatbelt sandbox profile
char* ErrorBuf = nullptr;
if (sandbox_init(SandboxProfile.c_str(), 0, &ErrorBuf) != 0)
{
- // sandbox_init failed — write error to pipe and exit
+ // sandbox_init failed - write error to pipe and exit
if (ErrorBuf)
{
WriteErrorAndExit(ErrorPipe[1], ErrorBuf, 0);
@@ -259,7 +262,7 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (m_Sandboxed)
{
- // Close write end of error pipe — parent only reads
+ // Close write end of error pipe - parent only reads
close(ErrorPipe[1]);
// Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC
@@ -279,7 +282,8 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action)
// Clean up the sandbox in the background
m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath));
- ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf);
+ Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf);
+ ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason);
Action->SetActionState(RunnerAction::State::Failed);
return SubmitResult{.IsAccepted = false};
@@ -467,7 +471,7 @@ MacProcessRunner::SampleProcessCpu(RunningAction& Running)
const uint64_t CurrentOsTicks = Info.pti_total_user + Info.pti_total_system;
const uint64_t NowTicks = GetHifreqTimerValue();
- // Cumulative CPU seconds (absolute, available from first sample): ns → seconds
+ // Cumulative CPU seconds (absolute, available from first sample): ns -> seconds
Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 1'000'000'000.0), std::memory_order_relaxed);
if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0)
@@ -476,7 +480,7 @@ MacProcessRunner::SampleProcessCpu(RunningAction& Running)
if (ElapsedMs > 0)
{
const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks;
- // ns → ms: divide by 1,000,000; then as percent of elapsed ms
+ // ns -> ms: divide by 1,000,000; then as percent of elapsed ms
const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 1'000'000.0 / ElapsedMs * 100.0);
Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed);
}
diff --git a/src/zencompute/runners/managedrunner.cpp b/src/zencompute/runners/managedrunner.cpp
new file mode 100644
index 000000000..a4f586852
--- /dev/null
+++ b/src/zencompute/runners/managedrunner.cpp
@@ -0,0 +1,279 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "managedrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/except_fmt.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/scopeguard.h>
+# include <zencore/timer.h>
+# include <zencore/trace.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <asio/io_context.hpp>
+# include <asio/executor_work_guard.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+ManagedProcessRunner::ManagedProcessRunner(ChunkResolver& Resolver,
+ const std::filesystem::path& BaseDir,
+ DeferredDirectoryDeleter& Deleter,
+ WorkerThreadPool& WorkerPool,
+ int32_t MaxConcurrentActions)
+: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions)
+, m_IoContext(std::make_unique<asio::io_context>())
+, m_SubprocessManager(std::make_unique<SubprocessManager>(*m_IoContext))
+{
+ m_ProcessGroup = m_SubprocessManager->CreateGroup("compute-workers");
+
+ // Run the io_context on a small thread pool so that exit callbacks and
+ // metrics sampling are dispatched without blocking each other.
+ for (int i = 0; i < kIoThreadCount; ++i)
+ {
+ m_IoThreads.emplace_back([this, i] {
+ SetCurrentThreadName(fmt::format("mrunner_{}", i));
+
+ // work_guard keeps run() alive even when there is no pending work yet
+ auto WorkGuard = asio::make_work_guard(*m_IoContext);
+
+ m_IoContext->run();
+ });
+ }
+}
+
+ManagedProcessRunner::~ManagedProcessRunner()
+{
+ try
+ {
+ Shutdown();
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_WARN("exception during managed process runner shutdown: {}", Ex.what());
+ }
+}
+
+void
+ManagedProcessRunner::Shutdown()
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::Shutdown");
+ m_AcceptNewActions = false;
+
+ CancelRunningActions();
+
+ // Tear down the SubprocessManager before stopping the io_context so that
+ // any in-flight callbacks are drained cleanly.
+ if (m_SubprocessManager)
+ {
+ m_SubprocessManager->DestroyGroup("compute-workers");
+ m_ProcessGroup = nullptr;
+ m_SubprocessManager.reset();
+ }
+
+ if (m_IoContext)
+ {
+ m_IoContext->stop();
+ }
+
+ for (std::thread& Thread : m_IoThreads)
+ {
+ if (Thread.joinable())
+ {
+ Thread.join();
+ }
+ }
+ m_IoThreads.clear();
+}
+
+SubmitResult
+ManagedProcessRunner::SubmitAction(Ref<RunnerAction> Action)
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::SubmitAction");
+ std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action);
+
+ if (!Prepared)
+ {
+ return SubmitResult{.IsAccepted = false};
+ }
+
+ CbObject WorkerDescription = Prepared->WorkerPackage.GetObject();
+
+ // Parse environment variables from worker descriptor ("KEY=VALUE" strings)
+ // into the key-value pairs expected by CreateProcOptions.
+ std::vector<std::pair<std::string, std::string>> EnvPairs;
+ for (auto& It : WorkerDescription["environment"sv])
+ {
+ std::string_view Str = It.AsString();
+ size_t Eq = Str.find('=');
+ if (Eq != std::string_view::npos)
+ {
+ EnvPairs.emplace_back(std::string(Str.substr(0, Eq)), std::string(Str.substr(Eq + 1)));
+ }
+ }
+
+ // Build command line
+ std::string_view ExecPath = WorkerDescription["path"sv].AsString();
+ std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred();
+
+ std::string CommandLine = fmt::format("\"{}\" -Build=build.action"sv, ExePath.string());
+
+ ZEN_DEBUG("Executing (managed): '{}' (sandbox='{}')", CommandLine, Prepared->SandboxPath);
+
+ CreateProcOptions Options;
+ Options.WorkingDirectory = &Prepared->SandboxPath;
+ Options.Flags = CreateProcOptions::Flag_NoConsole | CreateProcOptions::Flag_BelowNormalPriority;
+ Options.Environment = std::move(EnvPairs);
+
+ const int32_t ActionLsn = Prepared->ActionLsn;
+
+ ManagedProcess* Proc = nullptr;
+
+ try
+ {
+ Proc = m_ProcessGroup->Spawn(ExePath, CommandLine, Options, [this, ActionLsn](ManagedProcess& /*Process*/, int ExitCode) {
+ OnProcessExit(ActionLsn, ExitCode);
+ });
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_ERROR("Failed to spawn process for action LSN {}: {}", ActionLsn, Ex.what());
+ m_DeferredDeleter.Enqueue(ActionLsn, std::move(Prepared->SandboxPath));
+ return SubmitResult{.IsAccepted = false};
+ }
+
+ {
+ Ref<RunningAction> NewAction{new RunningAction()};
+ NewAction->Action = Action;
+ NewAction->ProcessHandle = static_cast<void*>(Proc);
+ NewAction->Pid = Proc->Pid();
+ NewAction->SandboxPath = std::move(Prepared->SandboxPath);
+
+ RwLock::ExclusiveLockScope _(m_RunningLock);
+ m_RunningMap[ActionLsn] = std::move(NewAction);
+ }
+
+ Action->SetActionState(RunnerAction::State::Running);
+
+ ZEN_DEBUG("Managed runner: action LSN {} -> PID {}", ActionLsn, Proc->Pid());
+
+ return SubmitResult{.IsAccepted = true};
+}
+
+void
+ManagedProcessRunner::OnProcessExit(int ActionLsn, int ExitCode)
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::OnProcessExit");
+
+ Ref<RunningAction> Running;
+
+ m_RunningLock.WithExclusiveLock([&] {
+ auto It = m_RunningMap.find(ActionLsn);
+ if (It != m_RunningMap.end())
+ {
+ Running = std::move(It->second);
+ m_RunningMap.erase(It);
+ }
+ });
+
+ if (!Running)
+ {
+ return;
+ }
+
+ ZEN_DEBUG("Managed runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"), ActionLsn, Running->Pid, ExitCode);
+
+ Running->ExitCode = ExitCode;
+
+ // Capture final CPU metrics from the managed process before it is removed.
+ auto* Proc = static_cast<ManagedProcess*>(Running->ProcessHandle);
+ if (Proc)
+ {
+ ProcessMetrics Metrics = Proc->GetLatestMetrics();
+ float CpuMs = static_cast<float>(Metrics.UserTimeMs + Metrics.KernelTimeMs);
+ Running->Action->CpuSeconds.store(CpuMs / 1000.0f, std::memory_order_relaxed);
+
+ float CpuPct = Proc->GetCpuUsagePercent();
+ if (CpuPct >= 0.0f)
+ {
+ Running->Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed);
+ }
+ }
+
+ Running->ProcessHandle = nullptr;
+
+ std::vector<Ref<RunningAction>> CompletedActions;
+ CompletedActions.push_back(std::move(Running));
+ ProcessCompletedActions(CompletedActions);
+}
+
+void
+ManagedProcessRunner::CancelRunningActions()
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::CancelRunningActions");
+
+ std::unordered_map<int, Ref<RunningAction>> RunningMap;
+ m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); });
+
+ if (RunningMap.empty())
+ {
+ return;
+ }
+
+ ZEN_INFO("cancelling {} running actions via process group", RunningMap.size());
+
+ Stopwatch Timer;
+
+ // Kill all processes in the group atomically (TerminateJobObject on Windows,
+ // SIGTERM+SIGKILL on POSIX).
+ if (m_ProcessGroup)
+ {
+ m_ProcessGroup->KillAll();
+ }
+
+ for (auto& [Lsn, Running] : RunningMap)
+ {
+ m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath));
+ Running->Action->SetActionState(RunnerAction::State::Failed);
+ }
+
+ ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+}
+
+bool
+ManagedProcessRunner::CancelAction(int ActionLsn)
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::CancelAction");
+
+ ManagedProcess* Proc = nullptr;
+
+ m_RunningLock.WithSharedLock([&] {
+ auto It = m_RunningMap.find(ActionLsn);
+ if (It != m_RunningMap.end() && It->second->ProcessHandle != nullptr)
+ {
+ Proc = static_cast<ManagedProcess*>(It->second->ProcessHandle);
+ }
+ });
+
+ if (!Proc)
+ {
+ return false;
+ }
+
+ // Terminate the process. The exit callback will handle the rest
+ // (remove from running map, gather outputs or mark failed).
+ Proc->Terminate(222);
+
+ ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn);
+ return true;
+}
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/managedrunner.h b/src/zencompute/runners/managedrunner.h
new file mode 100644
index 000000000..21a44d43c
--- /dev/null
+++ b/src/zencompute/runners/managedrunner.h
@@ -0,0 +1,64 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "localrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zenutil/process/subprocessmanager.h>
+
+# include <memory>
+
+namespace asio {
+class io_context;
+}
+
+namespace zen::compute {
+
+/** Cross-platform process runner backed by SubprocessManager.
+
+ Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting,
+ input/output handling, and shared action preparation. Replaces the polling-based
+ monitor thread with async exit callbacks driven by SubprocessManager, and
+ delegates CPU/memory metrics sampling to the manager's built-in round-robin
+ sampler.
+
+ A ProcessGroup (backed by a JobObject on Windows, process group on POSIX) is
+ used for bulk cancellation on shutdown.
+
+ This runner does not perform any platform-specific sandboxing (AppContainer,
+ namespaces, Seatbelt). It is intended as a simpler, cross-platform alternative
+ to the platform-specific runners for non-sandboxed workloads.
+ */
+class ManagedProcessRunner : public LocalProcessRunner
+{
+public:
+ ManagedProcessRunner(ChunkResolver& Resolver,
+ const std::filesystem::path& BaseDir,
+ DeferredDirectoryDeleter& Deleter,
+ WorkerThreadPool& WorkerPool,
+ int32_t MaxConcurrentActions = 0);
+ ~ManagedProcessRunner();
+
+ void Shutdown() override;
+ [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
+ void CancelRunningActions() override;
+ bool CancelAction(int ActionLsn) override;
+ [[nodiscard]] bool IsHealthy() override { return true; }
+
+private:
+ static constexpr int kIoThreadCount = 4;
+
+ // Exit callback posted on an io_context thread.
+ void OnProcessExit(int ActionLsn, int ExitCode);
+
+ std::unique_ptr<asio::io_context> m_IoContext;
+ std::unique_ptr<SubprocessManager> m_SubprocessManager;
+ ProcessGroup* m_ProcessGroup = nullptr;
+ std::vector<std::thread> m_IoThreads;
+};
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp
index ce6a81173..08f381b7f 100644
--- a/src/zencompute/runners/remotehttprunner.cpp
+++ b/src/zencompute/runners/remotehttprunner.cpp
@@ -20,6 +20,7 @@
# include <zenstore/cidstore.h>
# include <span>
+# include <unordered_set>
//////////////////////////////////////////////////////////////////////////
@@ -38,6 +39,7 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver,
, m_ChunkResolver{InChunkResolver}
, m_WorkerPool{InWorkerPool}
, m_HostName{HostName}
+, m_DisplayName{HostName}
, m_BaseUrl{fmt::format("{}/compute", HostName)}
, m_Http(m_BaseUrl)
, m_InstanceId(Oid::NewOid())
@@ -59,6 +61,15 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver,
m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this};
}
+void
+RemoteHttpRunner::SetRemoteHostname(std::string_view Hostname)
+{
+ if (!Hostname.empty())
+ {
+ m_DisplayName = fmt::format("{} ({})", m_HostName, Hostname);
+ }
+}
+
RemoteHttpRunner::~RemoteHttpRunner()
{
Shutdown();
@@ -108,6 +119,7 @@ RemoteHttpRunner::Shutdown()
for (auto& [RemoteLsn, HttpAction] : Remaining)
{
ZEN_DEBUG("shutdown: marking remote action LSN {} (local LSN {}) as Failed", RemoteLsn, HttpAction.Action->ActionLsn);
+ HttpAction.Action->FailureReason = "remote runner shutdown";
HttpAction.Action->SetActionState(RunnerAction::State::Failed);
}
}
@@ -213,11 +225,13 @@ RemoteHttpRunner::QueryCapacity()
return 0;
}
- // Estimate how much more work we're ready to accept
+ // Estimate how much more work we're ready to accept.
+ // Include actions currently being submitted over HTTP so we don't
+ // keep queueing new submissions while previous ones are still in flight.
RwLock::SharedLockScope _{m_RunningLock};
- size_t RunningCount = m_RemoteRunningMap.size();
+ size_t RunningCount = m_RemoteRunningMap.size() + m_InFlightSubmissions.load(std::memory_order_relaxed);
if (RunningCount >= size_t(m_MaxRunningActions))
{
@@ -232,6 +246,9 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
{
ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions");
+ m_InFlightSubmissions.fetch_add(Actions.size(), std::memory_order_relaxed);
+ auto InFlightGuard = MakeGuard([&] { m_InFlightSubmissions.fetch_sub(Actions.size(), std::memory_order_relaxed); });
+
if (Actions.size() <= 1)
{
std::vector<SubmitResult> Results;
@@ -246,7 +263,7 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
// Collect distinct QueueIds and ensure remote queues exist once per queue
- std::unordered_map<int, Oid> QueueTokens; // QueueId → remote token (0 stays as Zero)
+ std::unordered_map<int, Oid> QueueTokens; // QueueId -> remote token (0 stays as Zero)
for (const Ref<RunnerAction>& Action : Actions)
{
@@ -359,108 +376,141 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action)
}
}
- // Enqueue job. If the remote returns FailedDependency (424), it means it
- // cannot resolve the worker/function — re-register the worker and retry once.
+ // Submit the action to the remote. In eager-attach mode we build a
+ // CbPackage with all referenced attachments upfront to avoid the 404
+ // round-trip. In the default mode we POST the bare object first and
+ // only upload missing attachments if the remote requests them.
+ //
+ // In both modes, FailedDependency (424) triggers a worker re-register
+ // and a single retry.
CbObject Result;
HttpClient::Response WorkResponse;
HttpResponseCode WorkResponseCode{};
- for (int Attempt = 0; Attempt < 2; ++Attempt)
- {
- WorkResponse = m_Http.Post(SubmitUrl, ActionObj);
- WorkResponseCode = WorkResponse.StatusCode;
-
- if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0)
- {
- ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying",
- m_Http.GetBaseUri(),
- ActionId);
-
- (void)RegisterWorker(Action->Worker.Descriptor);
- }
- else
- {
- break;
- }
- }
-
- if (WorkResponseCode == HttpResponseCode::OK)
- {
- Result = WorkResponse.AsObject();
- }
- else if (WorkResponseCode == HttpResponseCode::NotFound)
+ if (m_EagerAttach)
{
- // Not all attachments are present
-
- // Build response package including all required attachments
-
CbPackage Pkg;
Pkg.SetObject(ActionObj);
- CbObject Response = WorkResponse.AsObject();
+ ActionObj.IterateAttachments([&](CbFieldView Field) {
+ const IoHash AttachHash = Field.AsHash();
- for (auto& Item : Response["need"sv])
- {
- const IoHash NeedHash = Item.AsHash();
-
- if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash))
{
uint64_t DataRawSize = 0;
IoHash DataRawHash;
CompressedBuffer Compressed =
CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
- ZEN_ASSERT(DataRawHash == NeedHash);
+ Pkg.AddAttachment(CbAttachment(Compressed, AttachHash));
+ m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed);
+ m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed);
+ }
+ });
+
+ for (int Attempt = 0; Attempt < 2; ++Attempt)
+ {
+ WorkResponse = m_Http.Post(SubmitUrl, Pkg);
+ WorkResponseCode = WorkResponse.StatusCode;
+
+ if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0)
+ {
+ ZEN_WARN("remote {} returned FailedDependency for action {} - re-registering worker and retrying",
+ m_Http.GetBaseUri(),
+ ActionId);
- Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+ (void)RegisterWorker(Action->Worker.Descriptor);
}
else
{
- // No such attachment
-
- return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)};
+ break;
}
}
+ }
+ else
+ {
+ for (int Attempt = 0; Attempt < 2; ++Attempt)
+ {
+ WorkResponse = m_Http.Post(SubmitUrl, ActionObj);
+ WorkResponseCode = WorkResponse.StatusCode;
- // Post resulting package
+ if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0)
+ {
+ ZEN_WARN("remote {} returned FailedDependency for action {} - re-registering worker and retrying",
+ m_Http.GetBaseUri(),
+ ActionId);
- HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg);
+ (void)RegisterWorker(Action->Worker.Descriptor);
+ }
+ else
+ {
+ break;
+ }
+ }
- if (!PayloadResponse)
+ if (WorkResponseCode == HttpResponseCode::NotFound)
{
- ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl);
+ // Remote needs attachments - resolve them and retry with a CbPackage
- // TODO: include more information about the failure in the response
+ CbPackage Pkg;
+ Pkg.SetObject(ActionObj);
- return {.IsAccepted = false, .Reason = "HTTP request failed"};
- }
- else if (PayloadResponse.StatusCode == HttpResponseCode::OK)
- {
- Result = PayloadResponse.AsObject();
- }
- else
- {
- // Unexpected response
-
- const int ResponseStatusCode = (int)PayloadResponse.StatusCode;
-
- ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})",
- ActionId,
- m_Http.GetBaseUri(),
- SubmitUrl,
- ResponseStatusCode,
- ToString(ResponseStatusCode));
-
- return {.IsAccepted = false,
- .Reason = fmt::format("unexpected response code {} {} from {}{}",
- ResponseStatusCode,
- ToString(ResponseStatusCode),
- m_Http.GetBaseUri(),
- SubmitUrl)};
+ CbObject Response = WorkResponse.AsObject();
+
+ for (auto& Item : Response["need"sv])
+ {
+ const IoHash NeedHash = Item.AsHash();
+
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ {
+ uint64_t DataRawSize = 0;
+ IoHash DataRawHash;
+ CompressedBuffer Compressed =
+ CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
+
+ ZEN_ASSERT(DataRawHash == NeedHash);
+
+ Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+ m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed);
+ m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed);
+ }
+ else
+ {
+ return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)};
+ }
+ }
+
+ HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg);
+
+ if (!PayloadResponse)
+ {
+ ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl);
+ return {.IsAccepted = false, .Reason = "HTTP request failed"};
+ }
+
+ WorkResponse = std::move(PayloadResponse);
+ WorkResponseCode = WorkResponse.StatusCode;
}
}
+ if (WorkResponseCode == HttpResponseCode::OK)
+ {
+ Result = WorkResponse.AsObject();
+ }
+ else if (!WorkResponse)
+ {
+ ZEN_WARN("submit of action {} to {}{} failed", ActionId, m_Http.GetBaseUri(), SubmitUrl);
+ return {.IsAccepted = false, .Reason = "HTTP request failed"};
+ }
+ else if (!IsHttpSuccessCode(WorkResponseCode))
+ {
+ const int Code = static_cast<int>(WorkResponseCode);
+ ZEN_WARN("submit of action {} to {}{} returned {} {}", ActionId, m_Http.GetBaseUri(), SubmitUrl, Code, ToString(Code));
+ return {.IsAccepted = false,
+ .Reason = fmt::format("unexpected response code {} {} from {}{}", Code, ToString(Code), m_Http.GetBaseUri(), SubmitUrl)};
+ }
+
if (Result)
{
if (const int32_t LsnField = Result["lsn"].AsInt32(0))
@@ -512,83 +562,111 @@ RemoteHttpRunner::SubmitActionBatch(const std::string& SubmitUrl, const std::vec
CbObjectWriter Body;
Body.BeginArray("actions"sv);
+ std::unordered_set<IoHash, IoHash::Hasher> AttachmentsSeen;
+
for (const Ref<RunnerAction>& Action : Actions)
{
Action->ExecutionLocation = m_HostName;
MaybeDumpAction(Action->ActionLsn, Action->ActionObj);
Body.AddObject(Action->ActionObj);
+
+ if (m_EagerAttach)
+ {
+ Action->ActionObj.IterateAttachments([&](CbFieldView Field) { AttachmentsSeen.insert(Field.AsHash()); });
+ }
}
Body.EndArray();
- // POST the batch
-
- HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save());
-
- if (Response.StatusCode == HttpResponseCode::OK)
- {
- return ParseBatchResponse(Response, Actions);
- }
+ // In eager-attach mode, build a CbPackage with all referenced attachments
+ // so the remote can accept in a single round-trip. Otherwise POST a bare
+ // CbObject and handle the 404 need-list flow.
- if (Response.StatusCode == HttpResponseCode::NotFound)
+ if (m_EagerAttach)
{
- // Server needs attachments — resolve them and retry with a CbPackage
-
- CbObject NeedObj = Response.AsObject();
-
CbPackage Pkg;
Pkg.SetObject(Body.Save());
- for (auto& Item : NeedObj["need"sv])
+ for (const IoHash& AttachHash : AttachmentsSeen)
{
- const IoHash NeedHash = Item.AsHash();
-
- if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash))
{
uint64_t DataRawSize = 0;
IoHash DataRawHash;
CompressedBuffer Compressed =
CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
- ZEN_ASSERT(DataRawHash == NeedHash);
-
- Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
- }
- else
- {
- ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash);
- return FallbackToIndividualSubmit(Actions);
+ Pkg.AddAttachment(CbAttachment(Compressed, AttachHash));
+ m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed);
+ m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed);
}
}
- HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg);
+ HttpClient::Response Response = m_Http.Post(SubmitUrl, Pkg);
- if (RetryResponse.StatusCode == HttpResponseCode::OK)
+ if (Response.StatusCode == HttpResponseCode::OK)
{
- return ParseBatchResponse(RetryResponse, Actions);
+ return ParseBatchResponse(Response, Actions);
}
-
- ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit",
- (int)RetryResponse.StatusCode,
- ToString(RetryResponse.StatusCode));
- return FallbackToIndividualSubmit(Actions);
- }
-
- // Unexpected status or connection error — fall back to individual submission
-
- if (Response)
- {
- ZEN_WARN("batch submit to {}{} returned {} {} — falling back to individual submit",
- m_Http.GetBaseUri(),
- SubmitUrl,
- (int)Response.StatusCode,
- ToString(Response.StatusCode));
}
else
{
- ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl);
+ HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save());
+
+ if (Response.StatusCode == HttpResponseCode::OK)
+ {
+ return ParseBatchResponse(Response, Actions);
+ }
+
+ if (Response.StatusCode == HttpResponseCode::NotFound)
+ {
+ CbObject NeedObj = Response.AsObject();
+
+ CbPackage Pkg;
+ Pkg.SetObject(Body.Save());
+
+ for (auto& Item : NeedObj["need"sv])
+ {
+ const IoHash NeedHash = Item.AsHash();
+
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ {
+ uint64_t DataRawSize = 0;
+ IoHash DataRawHash;
+ CompressedBuffer Compressed =
+ CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
+
+ ZEN_ASSERT(DataRawHash == NeedHash);
+
+ Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+ m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed);
+ m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed);
+ }
+ else
+ {
+ ZEN_WARN("batch submit: missing attachment {} - falling back to individual submit", NeedHash);
+ return FallbackToIndividualSubmit(Actions);
+ }
+ }
+
+ HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg);
+
+ if (RetryResponse.StatusCode == HttpResponseCode::OK)
+ {
+ return ParseBatchResponse(RetryResponse, Actions);
+ }
+
+ ZEN_WARN("batch submit retry failed with {} {} - falling back to individual submit",
+ (int)RetryResponse.StatusCode,
+ ToString(RetryResponse.StatusCode));
+ return FallbackToIndividualSubmit(Actions);
+ }
}
+ // Unexpected status or connection error - fall back to individual submission
+
+ ZEN_WARN("batch submit to {}{} failed - falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl);
+
return FallbackToIndividualSubmit(Actions);
}
@@ -849,7 +927,7 @@ RemoteHttpRunner::MonitorThreadFunction()
SweepOnce();
}
- // Signal received — may be a WS wakeup or a quit signal
+ // Signal received - may be a WS wakeup or a quit signal
SweepOnce();
} while (m_MonitorThreadEnabled);
@@ -869,9 +947,10 @@ RemoteHttpRunner::SweepRunningActions()
{
for (auto& FieldIt : Completed["completed"sv])
{
- CbObjectView EntryObj = FieldIt.AsObjectView();
- const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32();
- std::string_view StateName = EntryObj["state"sv].AsString();
+ CbObjectView EntryObj = FieldIt.AsObjectView();
+ const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32();
+ std::string_view StateName = EntryObj["state"sv].AsString();
+ std::string_view FailureReason = EntryObj["reason"sv].AsString();
RunnerAction::State RemoteState = RunnerAction::FromString(StateName);
@@ -884,6 +963,7 @@ RemoteHttpRunner::SweepRunningActions()
{
HttpRunningAction CompletedAction = std::move(CompleteIt->second);
CompletedAction.RemoteState = RemoteState;
+ CompletedAction.FailureReason = std::string(FailureReason);
if (RemoteState == RunnerAction::State::Completed && ResponseJob)
{
@@ -927,16 +1007,44 @@ RemoteHttpRunner::SweepRunningActions()
{
const int ActionLsn = HttpAction.Action->ActionLsn;
- ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}",
- HttpAction.Action->ActionId,
- ActionLsn,
- HttpAction.RemoteActionLsn,
- RunnerAction::ToString(HttpAction.RemoteState));
-
if (HttpAction.RemoteState == RunnerAction::State::Completed)
{
+ ZEN_DEBUG("action {} LSN {} (remote LSN {}) completed on {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ HttpAction.RemoteActionLsn,
+ m_HostName);
HttpAction.Action->SetResult(std::move(HttpAction.ActionResults));
}
+ else if (HttpAction.RemoteState == RunnerAction::State::Failed || HttpAction.RemoteState == RunnerAction::State::Abandoned)
+ {
+ HttpAction.Action->FailureReason = HttpAction.FailureReason;
+ if (HttpAction.FailureReason.empty())
+ {
+ ZEN_WARN("action {} ({}) {} on remote {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ RunnerAction::ToString(HttpAction.RemoteState),
+ m_HostName);
+ }
+ else
+ {
+ ZEN_WARN("action {} ({}) {} on remote {}: {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ RunnerAction::ToString(HttpAction.RemoteState),
+ m_HostName,
+ HttpAction.FailureReason);
+ }
+ }
+ else
+ {
+ ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ HttpAction.RemoteActionLsn,
+ RunnerAction::ToString(HttpAction.RemoteState));
+ }
HttpAction.Action->SetActionState(HttpAction.RemoteState);
}
diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h
index c17d0cf2a..521bf2f82 100644
--- a/src/zencompute/runners/remotehttprunner.h
+++ b/src/zencompute/runners/remotehttprunner.h
@@ -54,8 +54,10 @@ public:
[[nodiscard]] virtual size_t QueryCapacity() override;
[[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override;
virtual void CancelRemoteQueue(int QueueId) override;
+ [[nodiscard]] virtual std::string_view GetDisplayName() const override { return m_DisplayName; }
std::string_view GetHostName() const { return m_HostName; }
+ void SetRemoteHostname(std::string_view Hostname);
protected:
LoggerRef Log() { return m_Log; }
@@ -65,12 +67,15 @@ private:
ChunkResolver& m_ChunkResolver;
WorkerThreadPool& m_WorkerPool;
std::string m_HostName;
+ std::string m_DisplayName;
std::string m_BaseUrl;
HttpClient m_Http;
- std::atomic<bool> m_AcceptNewActions{true};
- int32_t m_MaxRunningActions = 256; // arbitrary limit for testing
- int32_t m_MaxBatchSize = 50;
+ std::atomic<bool> m_AcceptNewActions{true};
+ int32_t m_MaxRunningActions = 256; // arbitrary limit for testing
+ int32_t m_MaxBatchSize = 50;
+ bool m_EagerAttach = true; ///< Send attachments with every submit instead of the two-step 404 retry
+ std::atomic<size_t> m_InFlightSubmissions{0}; // actions currently being submitted over HTTP
struct HttpRunningAction
{
@@ -78,6 +83,7 @@ private:
int RemoteActionLsn = 0; // Remote LSN
RunnerAction::State RemoteState = RunnerAction::State::Failed;
CbPackage ActionResults;
+ std::string FailureReason;
};
RwLock m_RunningLock;
@@ -90,7 +96,7 @@ private:
size_t SweepRunningActions();
RwLock m_QueueTokenLock;
- std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId → remote queue token
+ std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId -> remote queue token
// Stable identity for this runner instance, used as part of the idempotency key when
// creating remote queues. Generated once at construction and never changes.
diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp
index cd4b646e9..c6b3e82ea 100644
--- a/src/zencompute/runners/windowsrunner.cpp
+++ b/src/zencompute/runners/windowsrunner.cpp
@@ -21,6 +21,12 @@ ZEN_THIRD_PARTY_INCLUDES_START
# include <sddl.h>
ZEN_THIRD_PARTY_INCLUDES_END
+// JOB_OBJECT_UILIMIT_ERRORMODE is defined in winuser.h which may be
+// excluded by WIN32_LEAN_AND_MEAN.
+# if !defined(JOB_OBJECT_UILIMIT_ERRORMODE)
+# define JOB_OBJECT_UILIMIT_ERRORMODE 0x00000400
+# endif
+
namespace zen::compute {
using namespace std::literals;
@@ -34,38 +40,67 @@ WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver,
: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions)
, m_Sandboxed(Sandboxed)
{
- if (!m_Sandboxed)
+ // Create a job object shared by all child processes. Restricting the
+ // error-mode UI prevents crash dialogs (WER / Dr. Watson) from
+ // blocking the monitor thread when a worker process terminates
+ // abnormally.
+ m_JobObject = CreateJobObjectW(nullptr, nullptr);
+ if (m_JobObject)
{
- return;
+ JOBOBJECT_EXTENDED_LIMIT_INFORMATION ExtLimits{};
+ ExtLimits.BasicLimitInformation.LimitFlags =
+ JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION | JOB_OBJECT_LIMIT_PRIORITY_CLASS;
+ ExtLimits.BasicLimitInformation.PriorityClass = BELOW_NORMAL_PRIORITY_CLASS;
+ SetInformationJobObject(m_JobObject, JobObjectExtendedLimitInformation, &ExtLimits, sizeof(ExtLimits));
+
+ JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{};
+ UiRestrictions.UIRestrictionsClass = JOB_OBJECT_UILIMIT_ERRORMODE;
+ SetInformationJobObject(m_JobObject, JobObjectBasicUIRestrictions, &UiRestrictions, sizeof(UiRestrictions));
+
+ // Set error mode on this process so children inherit it. The
+ // UILIMIT_ERRORMODE restriction above prevents them from clearing
+ // SEM_NOGPFAULTERRORBOX.
+ SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX);
}
- // Build a unique profile name per process to avoid collisions
- m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId());
+ if (m_Sandboxed)
+ {
+ // Build a unique profile name per process to avoid collisions
+ m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId());
- // Clean up any stale profile from a previous crash
- DeleteAppContainerProfile(m_AppContainerName.c_str());
+ // Clean up any stale profile from a previous crash
+ DeleteAppContainerProfile(m_AppContainerName.c_str());
- PSID Sid = nullptr;
+ PSID Sid = nullptr;
- HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(),
- m_AppContainerName.c_str(), // display name
- m_AppContainerName.c_str(), // description
- nullptr, // no capabilities
- 0, // capability count
- &Sid);
+ HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(),
+ m_AppContainerName.c_str(), // display name
+ m_AppContainerName.c_str(), // description
+ nullptr, // no capabilities
+ 0, // capability count
+ &Sid);
- if (FAILED(Hr))
- {
- throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr));
- }
+ if (FAILED(Hr))
+ {
+ throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr));
+ }
- m_AppContainerSid = Sid;
+ m_AppContainerSid = Sid;
+
+ ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName));
+ }
- ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName));
+ StartMonitorThread();
}
WindowsProcessRunner::~WindowsProcessRunner()
{
+ if (m_JobObject)
+ {
+ CloseHandle(m_JobObject);
+ m_JobObject = nullptr;
+ }
+
if (m_AppContainerSid)
{
FreeSid(m_AppContainerSid);
@@ -172,9 +207,9 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action)
LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr;
LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr;
BOOL bInheritHandles = FALSE;
- DWORD dwCreationFlags = DETACHED_PROCESS;
+ DWORD dwCreationFlags = CREATE_SUSPENDED | DETACHED_PROCESS;
- ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed);
+ ZEN_DEBUG("{}: '{}' (sandbox='{}')", m_Sandboxed ? "Sandboxing" : "Executing", WideToUtf8(CommandLine.c_str()), Prepared->SandboxPath);
CommandLine.EnsureNulTerminated();
@@ -260,14 +295,21 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action)
}
}
- CloseHandle(ProcessInformation.hThread);
+ if (m_JobObject)
+ {
+ AssignProcessToJobObject(m_JobObject, ProcessInformation.hProcess);
+ }
- Ref<RunningAction> NewAction{new RunningAction()};
- NewAction->Action = Action;
- NewAction->ProcessHandle = ProcessInformation.hProcess;
- NewAction->SandboxPath = std::move(Prepared->SandboxPath);
+ ResumeThread(ProcessInformation.hThread);
+ CloseHandle(ProcessInformation.hThread);
{
+ Ref<RunningAction> NewAction{new RunningAction()};
+ NewAction->Action = Action;
+ NewAction->ProcessHandle = ProcessInformation.hProcess;
+ NewAction->Pid = ProcessInformation.dwProcessId;
+ NewAction->SandboxPath = std::move(Prepared->SandboxPath);
+
RwLock::ExclusiveLockScope _(m_RunningLock);
m_RunningMap[Prepared->ActionLsn] = std::move(NewAction);
@@ -275,6 +317,8 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action)
Action->SetActionState(RunnerAction::State::Running);
+ ZEN_DEBUG("Local runner: action LSN {} -> PID {}", Action->ActionLsn, ProcessInformation.dwProcessId);
+
return SubmitResult{.IsAccepted = true};
}
@@ -294,6 +338,11 @@ WindowsProcessRunner::SweepRunningActions()
if (IsSuccess && ExitCode != STILL_ACTIVE)
{
+ ZEN_DEBUG("Local runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"),
+ Running->Action->ActionLsn,
+ Running->Pid,
+ ExitCode);
+
CloseHandle(Running->ProcessHandle);
Running->ProcessHandle = INVALID_HANDLE_VALUE;
Running->ExitCode = ExitCode;
@@ -436,7 +485,7 @@ WindowsProcessRunner::SampleProcessCpu(RunningAction& Running)
const uint64_t CurrentOsTicks = FtToU64(KernelTime) + FtToU64(UserTime);
const uint64_t NowTicks = GetHifreqTimerValue();
- // Cumulative CPU seconds (absolute, available from first sample): 100ns → seconds
+ // Cumulative CPU seconds (absolute, available from first sample): 100ns -> seconds
Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 10'000'000.0), std::memory_order_relaxed);
if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0)
@@ -445,7 +494,7 @@ WindowsProcessRunner::SampleProcessCpu(RunningAction& Running)
if (ElapsedMs > 0)
{
const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks;
- // 100ns → ms: divide by 10000; then as percent of elapsed ms
+ // 100ns -> ms: divide by 10000; then as percent of elapsed ms
const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 10000.0 / ElapsedMs * 100.0);
Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed);
}
diff --git a/src/zencompute/runners/windowsrunner.h b/src/zencompute/runners/windowsrunner.h
index 9f2385cc4..adeaf02fc 100644
--- a/src/zencompute/runners/windowsrunner.h
+++ b/src/zencompute/runners/windowsrunner.h
@@ -46,6 +46,7 @@ private:
bool m_Sandboxed = false;
PSID m_AppContainerSid = nullptr;
std::wstring m_AppContainerName;
+ HANDLE m_JobObject = nullptr;
};
} // namespace zen::compute
diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp
index 506bec73b..29ab93663 100644
--- a/src/zencompute/runners/winerunner.cpp
+++ b/src/zencompute/runners/winerunner.cpp
@@ -36,6 +36,8 @@ WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver,
sigemptyset(&Action.sa_mask);
Action.sa_handler = SIG_DFL;
sigaction(SIGCHLD, &Action, nullptr);
+
+ StartMonitorThread();
}
SubmitResult
@@ -94,7 +96,9 @@ WineProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (ChildPid == 0)
{
- // Child process
+ // Child process - lower priority so workers don't starve the main server
+ nice(5);
+
if (chdir(SandboxPathStr.c_str()) != 0)
{
_exit(127);
diff --git a/src/zencore/compactbinaryjson.cpp b/src/zencore/compactbinaryjson.cpp
index da560a449..5bfbd5e3e 100644
--- a/src/zencore/compactbinaryjson.cpp
+++ b/src/zencore/compactbinaryjson.cpp
@@ -11,6 +11,8 @@
#include <zencore/testing.h>
#include <fmt/format.h>
+#include <cmath>
+#include <limits>
#include <vector>
ZEN_THIRD_PARTY_INCLUDES_START
@@ -570,13 +572,37 @@ private:
break;
case Json::Type::NUMBER:
{
- if (FieldName.empty())
- {
- Writer.AddFloat(Json.number_value());
+ // If the JSON number has no fractional part and fits in an int64,
+ // store it as an integer so that AsInt32/AsInt64 work without
+ // requiring callers to go through AsFloat.
+ double Value = Json.number_value();
+ double IntPart;
+ bool IsIntegral = (std::modf(Value, &IntPart) == 0.0) &&
+ Value >= static_cast<double>(std::numeric_limits<int64_t>::min()) &&
+ Value <= static_cast<double>(std::numeric_limits<int64_t>::max());
+
+ if (IsIntegral)
+ {
+ int64_t IntValue = static_cast<int64_t>(Value);
+ if (FieldName.empty())
+ {
+ Writer.AddInteger(IntValue);
+ }
+ else
+ {
+ Writer.AddInteger(FieldName, IntValue);
+ }
}
else
{
- Writer.AddFloat(FieldName, Json.number_value());
+ if (FieldName.empty())
+ {
+ Writer.AddFloat(Value);
+ }
+ else
+ {
+ Writer.AddFloat(FieldName, Value);
+ }
}
}
break;
diff --git a/src/zencore/compactbinarypackage.cpp b/src/zencore/compactbinarypackage.cpp
index 56a292ca6..87b58baf7 100644
--- a/src/zencore/compactbinarypackage.cpp
+++ b/src/zencore/compactbinarypackage.cpp
@@ -684,14 +684,22 @@ namespace legacy {
Writer.Save(Ar);
}
- bool TryLoadCbPackage(CbPackage& Package, IoBuffer InBuffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper)
+ bool TryLoadCbPackage(CbPackage& Package,
+ IoBuffer InBuffer,
+ BufferAllocator Allocator,
+ CbPackage::AttachmentResolver* Mapper,
+ bool ValidateHashes)
{
BinaryReader Reader(InBuffer.Data(), InBuffer.Size());
- return TryLoadCbPackage(Package, Reader, Allocator, Mapper);
+ return TryLoadCbPackage(Package, Reader, Allocator, Mapper, ValidateHashes);
}
- bool TryLoadCbPackage(CbPackage& Package, BinaryReader& Reader, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper)
+ bool TryLoadCbPackage(CbPackage& Package,
+ BinaryReader& Reader,
+ BufferAllocator Allocator,
+ CbPackage::AttachmentResolver* Mapper,
+ bool ValidateHashes)
{
Package = CbPackage();
for (;;)
@@ -708,7 +716,11 @@ namespace legacy {
if (ValueField.IsBinary())
{
const MemoryView View = ValueField.AsBinaryView();
- if (View.GetSize() > 0)
+ if (View.GetSize() == 0)
+ {
+ return false;
+ }
+ else
{
SharedBuffer Buffer = SharedBuffer::MakeView(View, ValueField.GetOuterBuffer()).MakeOwned();
CbField HashField = LoadCompactBinary(Reader, Allocator);
@@ -748,7 +760,11 @@ namespace legacy {
{
const IoHash Hash = ValueField.AsHash();
- ZEN_ASSERT(Mapper);
+ if (!Mapper)
+ {
+ return false;
+ }
+
if (SharedBuffer AttachmentData = (*Mapper)(Hash))
{
IoHash RawHash;
@@ -763,6 +779,10 @@ namespace legacy {
}
else
{
+ if (ValidateHashes && IoHash::HashBuffer(AttachmentData) != Hash)
+ {
+ return false;
+ }
const CbValidateError ValidationResult = ValidateCompactBinary(AttachmentData.GetView(), CbValidateMode::All);
if (ValidationResult != CbValidateError::None)
{
@@ -801,13 +821,13 @@ namespace legacy {
#if ZEN_WITH_TESTS
void
-usonpackage_forcelink()
+cbpackage_forcelink()
{
}
TEST_SUITE_BEGIN("core.compactbinarypackage");
-TEST_CASE("usonpackage")
+TEST_CASE("cbpackage")
{
using namespace std::literals;
@@ -997,7 +1017,7 @@ TEST_CASE("usonpackage")
}
}
-TEST_CASE("usonpackage.serialization")
+TEST_CASE("cbpackage.serialization")
{
using namespace std::literals;
@@ -1303,7 +1323,7 @@ TEST_CASE("usonpackage.serialization")
}
}
-TEST_CASE("usonpackage.invalidpackage")
+TEST_CASE("cbpackage.invalidpackage")
{
const auto TestLoad = [](std::initializer_list<uint8_t> RawData, BufferAllocator Allocator = UniqueBuffer::Alloc) {
const MemoryView RawView = MakeMemoryView(RawData);
@@ -1345,6 +1365,90 @@ TEST_CASE("usonpackage.invalidpackage")
}
}
+TEST_CASE("cbpackage.legacy.invalidpackage")
+{
+ const auto TestLegacyLoad = [](std::initializer_list<uint8_t> RawData) {
+ const MemoryView RawView = MakeMemoryView(RawData);
+ IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(RawView.GetData()), RawView.GetSize());
+ CbPackage Package;
+ CHECK_FALSE(legacy::TryLoadCbPackage(Package, Buffer, &UniqueBuffer::Alloc));
+ };
+
+ SUBCASE("Empty") { TestLegacyLoad({}); }
+
+ SUBCASE("Zero size binary rejects")
+ {
+ // A binary field with zero payload size should be rejected (would desync the reader)
+ BinaryWriter Writer;
+ CbWriter Cb;
+ Cb.AddBinary(MemoryView()); // zero-size binary
+ Cb.Save(Writer);
+
+ IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(MakeMemoryView(Writer).GetData()), MakeMemoryView(Writer).GetSize());
+ CbPackage Package;
+ CHECK_FALSE(legacy::TryLoadCbPackage(Package, Buffer, &UniqueBuffer::Alloc));
+ }
+}
+
+TEST_CASE("cbpackage.legacy.hashresolution")
+{
+ // Build a valid legacy package with an object, then round-trip it
+ CbObjectWriter RootWriter;
+ RootWriter.AddString("name", "test");
+ CbObject RootObject = RootWriter.Save();
+
+ CbAttachment ObjectAttach(RootObject);
+
+ CbPackage OriginalPkg;
+ OriginalPkg.SetObject(RootObject);
+ OriginalPkg.AddAttachment(ObjectAttach);
+
+ BinaryWriter Writer;
+ legacy::SaveCbPackage(OriginalPkg, Writer);
+
+ IoBuffer Buffer(IoBuffer::Wrap, const_cast<void*>(MakeMemoryView(Writer).GetData()), MakeMemoryView(Writer).GetSize());
+ CbPackage LoadedPkg;
+ CHECK(legacy::TryLoadCbPackage(LoadedPkg, Buffer, &UniqueBuffer::Alloc));
+
+ // The hash-only path requires a mapper - without one it should fail
+ CbWriter HashOnlyCb;
+ HashOnlyCb.AddHash(ObjectAttach.GetHash());
+ HashOnlyCb.AddNull();
+ BinaryWriter HashOnlyWriter;
+ HashOnlyCb.Save(HashOnlyWriter);
+
+ IoBuffer HashOnlyBuffer(IoBuffer::Wrap,
+ const_cast<void*>(MakeMemoryView(HashOnlyWriter).GetData()),
+ MakeMemoryView(HashOnlyWriter).GetSize());
+ CbPackage HashOnlyPkg;
+ CHECK_FALSE(legacy::TryLoadCbPackage(HashOnlyPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, nullptr));
+
+ // With a mapper that returns valid data, it should succeed
+ CbPackage::AttachmentResolver Resolver = [&](const IoHash& Hash) -> SharedBuffer {
+ if (Hash == ObjectAttach.GetHash())
+ {
+ return RootObject.GetBuffer();
+ }
+ return {};
+ };
+ CHECK(legacy::TryLoadCbPackage(HashOnlyPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &Resolver));
+
+ // Build a different but structurally valid CbObject to use as mismatched data
+ CbObjectWriter DifferentWriter;
+ DifferentWriter.AddString("name", "different");
+ CbObject DifferentObject = DifferentWriter.Save();
+
+ CbPackage::AttachmentResolver BadResolver = [&](const IoHash&) -> SharedBuffer { return DifferentObject.GetBuffer(); };
+ CbPackage BadPkg;
+
+ // With ValidateHashes enabled and a mapper that returns mismatched data, it should fail
+ CHECK_FALSE(legacy::TryLoadCbPackage(BadPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &BadResolver, /*ValidateHashes*/ true));
+
+ // Without ValidateHashes, the mismatched data is accepted (structure is still valid CB)
+ CbPackage UncheckedPkg;
+ CHECK(legacy::TryLoadCbPackage(UncheckedPkg, HashOnlyBuffer, &UniqueBuffer::Alloc, &BadResolver, /*ValidateHashes*/ false));
+}
+
TEST_SUITE_END();
#endif
diff --git a/src/zencore/crashhandler.cpp b/src/zencore/crashhandler.cpp
index 31b8e6ce2..14904a4b2 100644
--- a/src/zencore/crashhandler.cpp
+++ b/src/zencore/crashhandler.cpp
@@ -56,7 +56,7 @@ CrashExceptionFilter(PEXCEPTION_POINTERS ExceptionInfo)
HANDLE Process = GetCurrentProcess();
HANDLE Thread = GetCurrentThread();
- // SymInitialize is safe to call if already initialized — it returns FALSE
+ // SymInitialize is safe to call if already initialized - it returns FALSE
// but existing state remains valid for SymFromAddr calls
SymInitialize(Process, NULL, TRUE);
diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp
index 0d361801f..5160bfdc6 100644
--- a/src/zencore/filesystem.cpp
+++ b/src/zencore/filesystem.cpp
@@ -114,6 +114,20 @@ struct ScopedFd
explicit operator bool() const { return Fd >= 0; }
};
+# if ZEN_PLATFORM_LINUX
+inline uint64_t
+StatMtime100Ns(const struct stat& S)
+{
+ return uint64_t(S.st_mtim.tv_sec) * 10'000'000ULL + uint64_t(S.st_mtim.tv_nsec) / 100;
+}
+# elif ZEN_PLATFORM_MAC
+inline uint64_t
+StatMtime100Ns(const struct stat& S)
+{
+ return uint64_t(S.st_mtimespec.tv_sec) * 10'000'000ULL + uint64_t(S.st_mtimespec.tv_nsec) / 100;
+}
+# endif
+
#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
#if ZEN_PLATFORM_WINDOWS
@@ -2123,7 +2137,7 @@ FileSystemTraversal::TraverseFileSystem(const std::filesystem::path& RootDir, Tr
}
else if (S_ISREG(Stat.st_mode))
{
- Visitor.VisitFile(RootDir, FileName, Stat.st_size, gsl::narrow<uint32_t>(Stat.st_mode), gsl::narrow<uint64_t>(Stat.st_mtime));
+ Visitor.VisitFile(RootDir, FileName, Stat.st_size, gsl::narrow<uint32_t>(Stat.st_mode), StatMtime100Ns(Stat));
}
else
{
@@ -2507,7 +2521,7 @@ GetModificationTickFromHandle(void* NativeHandle, std::error_code& Ec)
struct stat Stat;
if (0 == fstat(Fd, &Stat))
{
- return gsl::narrow<uint64_t>(Stat.st_mtime);
+ return StatMtime100Ns(Stat);
}
#endif
Ec = MakeErrorCodeFromLastError();
@@ -2546,7 +2560,7 @@ GetModificationTickFromPath(const std::filesystem::path& Filename)
{
ThrowLastError(fmt::format("Failed to get mode of file {}", Filename));
}
- return gsl::narrow<uint64_t>(Stat.st_mtime);
+ return StatMtime100Ns(Stat);
#endif
}
@@ -2589,7 +2603,7 @@ TryGetFileProperties(const std::filesystem::path& Path,
{
return false;
}
- OutModificationTick = gsl::narrow<uint64_t>(Stat.st_mtime);
+ OutModificationTick = StatMtime100Ns(Stat);
OutSize = size_t(Stat.st_size);
OutNativeModeOrAttributes = (uint32_t)Stat.st_mode;
return true;
@@ -2963,6 +2977,35 @@ GetEnvVariable(std::string_view VariableName)
return "";
}
+std::string
+ExpandEnvironmentVariables(std::string_view Input)
+{
+ std::string Result;
+ Result.reserve(Input.size());
+
+ for (size_t i = 0; i < Input.size(); ++i)
+ {
+ if (Input[i] == '%')
+ {
+ size_t End = Input.find('%', i + 1);
+ if (End != std::string_view::npos && End > i + 1)
+ {
+ std::string_view VarName = Input.substr(i + 1, End - i - 1);
+ std::string Value = GetEnvVariable(VarName);
+ if (!Value.empty())
+ {
+ Result += Value;
+ i = End;
+ continue;
+ }
+ }
+ }
+ Result += Input[i];
+ }
+
+ return Result;
+}
+
std::error_code
RotateFiles(const std::filesystem::path& Filename, std::size_t MaxFiles)
{
@@ -3275,14 +3318,25 @@ MakeSafeAbsolutePathInPlace(std::filesystem::path& Path)
{
if (!Path.empty())
{
- std::filesystem::path AbsolutePath = std::filesystem::absolute(Path).make_preferred();
+ Path = std::filesystem::absolute(Path).make_preferred();
#if ZEN_PLATFORM_WINDOWS
- const std::string_view Prefix = "\\\\?\\";
- const std::u8string PrefixU8(Prefix.begin(), Prefix.end());
- std::u8string PathString = AbsolutePath.u8string();
- if (!PathString.empty() && !PathString.starts_with(PrefixU8))
+ const std::u8string_view LongPathPrefix = u8"\\\\?\\";
+ const std::u8string_view UncPrefix = u8"\\\\";
+ const std::u8string_view LongPathUncPrefix = u8"\\\\?\\UNC\\";
+
+ std::u8string PathString = Path.u8string();
+ if (!PathString.empty() && !PathString.starts_with(LongPathPrefix))
{
- PathString.insert(0, PrefixU8);
+ if (PathString.starts_with(UncPrefix))
+ {
+ // UNC path: \\server\share -> \\?\UNC\server\share
+ PathString.replace(0, UncPrefix.size(), LongPathUncPrefix);
+ }
+ else
+ {
+ // Local path: C:\foo -> \\?\C:\foo
+ PathString.insert(0, LongPathPrefix);
+ }
Path = PathString;
}
#endif // ZEN_PLATFORM_WINDOWS
@@ -3408,7 +3462,7 @@ public:
ZEN_UNUSED(SystemGlobal);
std::string InstanceMapName = fmt::format("/{}", Name);
- ScopedFd FdGuard(shm_open(InstanceMapName.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666));
+ ScopedFd FdGuard(shm_open(InstanceMapName.c_str(), O_RDWR | O_CREAT, 0666));
if (!FdGuard)
{
return {};
@@ -4049,6 +4103,93 @@ TEST_CASE("SharedMemory")
CHECK(!OpenSharedMemory("SharedMemoryTest0", 482, false));
}
+TEST_CASE("filesystem.MakeSafeAbsolutePath")
+{
+# if ZEN_PLATFORM_WINDOWS
+ // Local path gets \\?\ prefix
+ {
+ std::filesystem::path Local = MakeSafeAbsolutePath("C:\\Users\\test");
+ CHECK(Local.u8string().starts_with(u8"\\\\?\\"));
+ CHECK(Local.u8string().find(u8"C:\\Users\\test") != std::u8string::npos);
+ }
+
+ // UNC path gets \\?\UNC\ prefix
+ {
+ std::filesystem::path Unc = MakeSafeAbsolutePath("\\\\server\\share\\path");
+ std::u8string UncStr = Unc.u8string();
+ CHECK_MESSAGE(UncStr.starts_with(u8"\\\\?\\UNC\\"), fmt::format("Expected \\\\?\\UNC\\ prefix, got '{}'", Unc));
+ CHECK_MESSAGE(UncStr.find(u8"server\\share\\path") != std::u8string::npos,
+ fmt::format("Expected server\\share\\path in '{}'", Unc));
+ // Must NOT produce \\?\\\server (double backslash after \\?\)
+ CHECK_MESSAGE(UncStr.find(u8"\\\\?\\\\\\") == std::u8string::npos,
+ fmt::format("Path contains invalid double-backslash after prefix: '{}'", Unc));
+ }
+
+ // Already-prefixed path is not double-prefixed
+ {
+ std::filesystem::path Already = MakeSafeAbsolutePath("\\\\?\\C:\\already\\prefixed");
+ size_t Count = 0;
+ std::u8string Str = Already.u8string();
+ for (size_t Pos = Str.find(u8"\\\\?\\"); Pos != std::u8string::npos; Pos = Str.find(u8"\\\\?\\", Pos + 1))
+ {
+ ++Count;
+ }
+ CHECK_EQ(Count, 1);
+ }
+
+ // Already-prefixed UNC path is not double-prefixed
+ {
+ std::filesystem::path AlreadyUnc = MakeSafeAbsolutePath("\\\\?\\UNC\\server\\share");
+ size_t Count = 0;
+ std::u8string Str = AlreadyUnc.u8string();
+ for (size_t Pos = Str.find(u8"\\\\?\\"); Pos != std::u8string::npos; Pos = Str.find(u8"\\\\?\\", Pos + 1))
+ {
+ ++Count;
+ }
+ CHECK_EQ(Count, 1);
+ }
+# endif // ZEN_PLATFORM_WINDOWS
+}
+
+TEST_CASE("ExpandEnvironmentVariables")
+{
+ // No variables - pass-through
+ CHECK_EQ(ExpandEnvironmentVariables("plain/path"), "plain/path");
+ CHECK_EQ(ExpandEnvironmentVariables(""), "");
+
+ // Single percent sign is not a variable reference
+ CHECK_EQ(ExpandEnvironmentVariables("50%"), "50%");
+
+ // Empty variable name (%%) is not expanded
+ CHECK_EQ(ExpandEnvironmentVariables("%%"), "%%");
+
+ // Known variable
+# if ZEN_PLATFORM_WINDOWS
+ // PATH is always set on Windows
+ std::string PathValue = GetEnvVariable("PATH");
+ CHECK(!PathValue.empty());
+ CHECK_EQ(ExpandEnvironmentVariables("%PATH%"), PathValue);
+ CHECK_EQ(ExpandEnvironmentVariables("prefix/%PATH%/suffix"), "prefix/" + PathValue + "/suffix");
+# else
+ std::string HomeValue = GetEnvVariable("HOME");
+ CHECK(!HomeValue.empty());
+ CHECK_EQ(ExpandEnvironmentVariables("%HOME%"), HomeValue);
+ CHECK_EQ(ExpandEnvironmentVariables("prefix/%HOME%/suffix"), "prefix/" + HomeValue + "/suffix");
+# endif
+
+ // Unknown variable is left unexpanded
+ CHECK_EQ(ExpandEnvironmentVariables("%ZEN_UNLIKELY_SET_VAR_12345%"), "%ZEN_UNLIKELY_SET_VAR_12345%");
+
+ // Multiple variables
+# if ZEN_PLATFORM_WINDOWS
+ std::string OSValue = GetEnvVariable("OS");
+ if (!OSValue.empty())
+ {
+ CHECK_EQ(ExpandEnvironmentVariables("%PATH%/%OS%"), PathValue + "/" + OSValue);
+ }
+# endif
+}
+
TEST_SUITE_END();
#endif
diff --git a/src/zencore/include/zencore/compactbinarypackage.h b/src/zencore/include/zencore/compactbinarypackage.h
index 64b62e2c0..148c0d3fd 100644
--- a/src/zencore/include/zencore/compactbinarypackage.h
+++ b/src/zencore/include/zencore/compactbinarypackage.h
@@ -278,10 +278,10 @@ public:
* @return The attachment, or null if the attachment is not found.
* @note The returned pointer is only valid until the attachments on this package are modified.
*/
- const CbAttachment* FindAttachment(const IoHash& Hash) const;
+ [[nodiscard]] const CbAttachment* FindAttachment(const IoHash& Hash) const;
/** Find an attachment if it exists in the package. */
- inline const CbAttachment* FindAttachment(const CbAttachment& Attachment) const { return FindAttachment(Attachment.GetHash()); }
+ [[nodiscard]] const CbAttachment* FindAttachment(const CbAttachment& Attachment) const { return FindAttachment(Attachment.GetHash()); }
/** Add the attachment to this package. */
inline void AddAttachment(const CbAttachment& Attachment) { AddAttachment(Attachment, nullptr); }
@@ -336,17 +336,26 @@ private:
IoHash ObjectHash;
};
+/** In addition to the above, we also support a legacy format which is used by
+ * the HTTP project store for historical reasons. Don't use the below functions
+ * for new code.
+ */
namespace legacy {
void SaveCbAttachment(const CbAttachment& Attachment, CbWriter& Writer);
void SaveCbPackage(const CbPackage& Package, CbWriter& Writer);
void SaveCbPackage(const CbPackage& Package, BinaryWriter& Ar);
- bool TryLoadCbPackage(CbPackage& Package, IoBuffer Buffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper = nullptr);
+ bool TryLoadCbPackage(CbPackage& Package,
+ IoBuffer Buffer,
+ BufferAllocator Allocator,
+ CbPackage::AttachmentResolver* Mapper = nullptr,
+ bool ValidateHashes = false);
bool TryLoadCbPackage(CbPackage& Package,
BinaryReader& Reader,
BufferAllocator Allocator,
- CbPackage::AttachmentResolver* Mapper = nullptr);
+ CbPackage::AttachmentResolver* Mapper = nullptr,
+ bool ValidateHashes = false);
} // namespace legacy
-void usonpackage_forcelink(); // internal
+void cbpackage_forcelink(); // internal
} // namespace zen
diff --git a/src/zencore/include/zencore/filesystem.h b/src/zencore/include/zencore/filesystem.h
index 6dc159a83..73769cdb4 100644
--- a/src/zencore/include/zencore/filesystem.h
+++ b/src/zencore/include/zencore/filesystem.h
@@ -400,6 +400,10 @@ void GetDirectoryContent(const std::filesystem::path& RootDir,
std::string GetEnvVariable(std::string_view VariableName);
+// Expands %VAR% environment variable references in a string.
+// Unknown or empty variables are left unexpanded.
+std::string ExpandEnvironmentVariables(std::string_view Input);
+
std::filesystem::path SearchPathForExecutable(std::string_view ExecutableName);
std::error_code RotateFiles(const std::filesystem::path& Filename, std::size_t MaxFiles);
diff --git a/src/zencore/include/zencore/fmtutils.h b/src/zencore/include/zencore/fmtutils.h
index 4ec05f901..a263c6f04 100644
--- a/src/zencore/include/zencore/fmtutils.h
+++ b/src/zencore/include/zencore/fmtutils.h
@@ -3,10 +3,7 @@
#pragma once
#include <zencore/filesystem.h>
-#include <zencore/guid.h>
-#include <zencore/iohash.h>
#include <zencore/string.h>
-#include <zencore/uid.h>
ZEN_THIRD_PARTY_INCLUDES_START
#include <fmt/format.h>
@@ -38,63 +35,49 @@ struct fmt::formatter<T> : fmt::formatter<std::string_view>
}
};
-// Custom formatting for some zencore types
+// Generic formatter for any type that is explicitly convertible to std::string_view.
+// This covers NiceNum, NiceBytes, ThousandsNum, StringBuilder, and similar types
+// without needing per-type fmt::formatter specializations.
template<typename T>
-requires DerivedFrom<T, zen::StringBuilderBase>
-struct fmt::formatter<T> : fmt::formatter<std::string_view>
+concept HasStringViewConversion = std::is_class_v<T> && requires(const T& v)
{
- template<typename FormatContext>
- auto format(const zen::StringBuilderBase& a, FormatContext& ctx) const
{
- return fmt::formatter<std::string_view>::format(a.ToView(), ctx);
- }
-};
+ std::string_view(v)
+ } -> std::same_as<std::string_view>;
+} && !HasFreeToString<T> && !std::is_same_v<T, std::string> && !std::is_same_v<T, std::string_view>;
-template<typename T>
-requires DerivedFrom<T, zen::NiceBase>
+template<HasStringViewConversion T>
struct fmt::formatter<T> : fmt::formatter<std::string_view>
{
template<typename FormatContext>
- auto format(const zen::NiceBase& a, FormatContext& ctx) const
+ auto format(const T& Value, FormatContext& ctx) const
{
- return fmt::formatter<std::string_view>::format(std::string_view(a), ctx);
+ return fmt::formatter<std::string_view>::format(std::string_view(Value), ctx);
}
};
-template<>
-struct fmt::formatter<zen::IoHash> : formatter<string_view>
-{
- template<typename FormatContext>
- auto format(const zen::IoHash& Hash, FormatContext& ctx) const
- {
- zen::IoHash::String_t String;
- Hash.ToHexString(String);
- return fmt::formatter<string_view>::format({String, zen::IoHash::StringLength}, ctx);
- }
-};
+// Generic formatter for any type with a ToString(StringBuilderBase&) member function.
+// This covers Guid, IoHash, Oid, and similar types without needing per-type
+// fmt::formatter specializations.
-template<>
-struct fmt::formatter<zen::Oid> : formatter<string_view>
+template<typename T>
+concept HasMemberToStringBuilder = std::is_class_v<T> && requires(const T& v, zen::StringBuilderBase& sb)
{
- template<typename FormatContext>
- auto format(const zen::Oid& Id, FormatContext& ctx) const
{
- zen::StringBuilder<32> String;
- Id.ToString(String);
- return fmt::formatter<string_view>::format({String.c_str(), zen::Oid::StringLength}, ctx);
- }
-};
+ v.ToString(sb)
+ } -> std::same_as<zen::StringBuilderBase&>;
+} && !HasFreeToString<T> && !HasStringViewConversion<T>;
-template<>
-struct fmt::formatter<zen::Guid> : formatter<string_view>
+template<HasMemberToStringBuilder T>
+struct fmt::formatter<T> : fmt::formatter<std::string_view>
{
template<typename FormatContext>
- auto format(const zen::Guid& Id, FormatContext& ctx) const
+ auto format(const T& Value, FormatContext& ctx) const
{
- zen::StringBuilder<48> String;
- Id.ToString(String);
- return fmt::formatter<string_view>::format({String.c_str(), zen::Guid::StringLength}, ctx);
+ zen::ExtendableStringBuilder<64> String;
+ Value.ToString(String);
+ return fmt::formatter<std::string_view>::format(String.ToView(), ctx);
}
};
diff --git a/src/zencore/include/zencore/hashutils.h b/src/zencore/include/zencore/hashutils.h
index 8abfd4b6e..e253d7015 100644
--- a/src/zencore/include/zencore/hashutils.h
+++ b/src/zencore/include/zencore/hashutils.h
@@ -4,6 +4,8 @@
#include <cstddef>
#include <functional>
+#include <string>
+#include <string_view>
#include <type_traits>
namespace zen {
@@ -35,4 +37,21 @@ CombineHashes(const Types&... Args)
return Seed;
}
+/** Transparent string hash for use with std::unordered_map/set.
+ Enables heterogeneous lookup so that a std::string_view can be used to
+ probe a std::string-keyed container without allocating a temporary std::string.
+
+ Usage:
+ std::unordered_map<std::string, V, TransparentStringHash, std::equal_to<>> Map;
+ Map.find(some_string_view); // no allocation
+ */
+struct TransparentStringHash
+{
+ using is_transparent = void;
+
+ size_t operator()(std::string_view Sv) const noexcept { return std::hash<std::string_view>{}(Sv); }
+ size_t operator()(const std::string& S) const noexcept { return std::hash<std::string_view>{}(S); }
+ size_t operator()(const char* S) const noexcept { return std::hash<std::string_view>{}(S); }
+};
+
} // namespace zen
diff --git a/src/zencore/include/zencore/iobuffer.h b/src/zencore/include/zencore/iobuffer.h
index 82c201edd..c6ba90692 100644
--- a/src/zencore/include/zencore/iobuffer.h
+++ b/src/zencore/include/zencore/iobuffer.h
@@ -109,10 +109,11 @@ public:
// Reference counting
- inline uint32_t AddRef() const { return AtomicIncrement(const_cast<IoBufferCore*>(this)->m_RefCount); }
- inline uint32_t Release() const
+ // See zen::RefCounted::AddRef/Release for ordering rationale.
+ inline uint32_t AddRef() const noexcept { return m_RefCount.fetch_add(1, std::memory_order_relaxed) + 1; }
+ inline uint32_t Release() const noexcept
{
- const uint32_t NewRefCount = AtomicDecrement(const_cast<IoBufferCore*>(this)->m_RefCount);
+ const uint32_t NewRefCount = m_RefCount.fetch_sub(1, std::memory_order_acq_rel) - 1;
if (NewRefCount == 0)
{
DeleteThis();
@@ -130,7 +131,7 @@ public:
//
void Materialize() const;
- void DeleteThis() const;
+ void DeleteThis() const noexcept;
void MakeOwned(bool Immutable = true);
inline void EnsureDataValid() const
@@ -228,14 +229,14 @@ public:
return ZenContentType((m_Flags.load(std::memory_order_relaxed) >> kContentTypeShift) & kContentTypeMask);
}
- inline uint32_t GetRefCount() const { return m_RefCount; }
+ inline uint32_t GetRefCount() const noexcept { return m_RefCount.load(std::memory_order_relaxed); }
protected:
- uint32_t m_RefCount = 0;
+ mutable std::atomic<uint32_t> m_RefCount = 0;
mutable std::atomic<uint32_t> m_Flags{0};
mutable const void* m_DataPtr = nullptr;
size_t m_DataBytes = 0;
- RefPtr<const IoBufferCore> m_OuterCore;
+ Ref<const IoBufferCore> m_OuterCore;
enum
{
@@ -413,9 +414,9 @@ public:
private:
// We have a shared "null" buffer core which we share, this is initialized static and never released which will
// cause a memory leak at exit. This does however save millions of memory allocations for null buffers
- static RefPtr<IoBufferCore> NullBufferCore;
+ static Ref<IoBufferCore> NullBufferCore;
- RefPtr<IoBufferCore> m_Core = NullBufferCore;
+ Ref<IoBufferCore> m_Core = NullBufferCore;
IoBuffer(IoBufferCore* Core) : m_Core(Core) {}
diff --git a/src/zencore/include/zencore/iohash.h b/src/zencore/include/zencore/iohash.h
index a619b0053..50c439b70 100644
--- a/src/zencore/include/zencore/iohash.h
+++ b/src/zencore/include/zencore/iohash.h
@@ -54,6 +54,7 @@ struct IoHash
static bool TryParse(std::string_view Str, IoHash& Hash);
const char* ToHexString(char* outString /* 40 characters + NUL terminator */) const;
StringBuilderBase& ToHexString(StringBuilderBase& outBuilder) const;
+ StringBuilderBase& ToString(StringBuilderBase& outBuilder) const { return ToHexString(outBuilder); }
std::string ToHexString() const;
static constexpr int StringLength = 40;
diff --git a/src/zencore/include/zencore/logbase.h b/src/zencore/include/zencore/logbase.h
index ad2ab218d..65f8a9dbe 100644
--- a/src/zencore/include/zencore/logbase.h
+++ b/src/zencore/include/zencore/logbase.h
@@ -8,7 +8,7 @@
#include <string_view>
namespace zen::logging {
-enum LogLevel : int
+enum LogLevel : int8_t
{
Trace,
Debug,
@@ -22,6 +22,7 @@ enum LogLevel : int
LogLevel ParseLogLevelString(std::string_view String);
std::string_view ToStringView(LogLevel Level);
+std::string_view ShortToStringView(LogLevel Level);
void SetLogLevel(LogLevel NewLogLevel);
LogLevel GetLogLevel();
@@ -49,11 +50,16 @@ struct SourceLocation
*/
struct LogPoint
{
- SourceLocation Location;
+ const char* Filename;
+ int Line;
LogLevel Level;
std::string_view FormatString;
+
+ [[nodiscard]] SourceLocation Location() const { return SourceLocation{Filename, Line}; }
};
+static_assert(sizeof(LogPoint) <= 32);
+
class Logger;
/** This is the base class for all loggers
@@ -91,6 +97,7 @@ struct LoggerRef
{
LoggerRef() = default;
explicit LoggerRef(logging::Logger& InLogger);
+ explicit LoggerRef(std::string_view LogCategory);
// This exists so that logging macros can pass LoggerRef or LogCategory
// to ZEN_LOG without needing to know which one it is
@@ -104,6 +111,8 @@ struct LoggerRef
bool ShouldLog(logging::LogLevel Level) const { return m_Logger->ShouldLog(Level); }
void SetLogLevel(logging::LogLevel NewLogLevel) { m_Logger->SetLevel(NewLogLevel); }
logging::LogLevel GetLogLevel() { return m_Logger->GetLevel(); }
+ std::string_view GetLogLevelString() { return logging::ToStringView(GetLogLevel()); }
+ std::string_view GetShortLogLevelString() { return logging::ShortToStringView(GetLogLevel()); }
void Flush();
diff --git a/src/zencore/include/zencore/logging.h b/src/zencore/include/zencore/logging.h
index 3427991d2..cf011fb1a 100644
--- a/src/zencore/include/zencore/logging.h
+++ b/src/zencore/include/zencore/logging.h
@@ -90,6 +90,34 @@ using zen::ConsoleLog;
using zen::ErrorLog;
using zen::Log;
+////////////////////////////////////////////////////////////////////////
+// Color helpers
+
+#define ZEN_RED(str) "\033[31m" str "\033[0m"
+#define ZEN_GREEN(str) "\033[32m" str "\033[0m"
+#define ZEN_YELLOW(str) "\033[33m" str "\033[0m"
+#define ZEN_BLUE(str) "\033[34m" str "\033[0m"
+#define ZEN_MAGENTA(str) "\033[35m" str "\033[0m"
+#define ZEN_CYAN(str) "\033[36m" str "\033[0m"
+#define ZEN_WHITE(str) "\033[37m" str "\033[0m"
+
+#define ZEN_BRIGHT_RED(str) "\033[91m" str "\033[0m"
+#define ZEN_BRIGHT_GREEN(str) "\033[92m" str "\033[0m"
+#define ZEN_BRIGHT_YELLOW(str) "\033[93m" str "\033[0m"
+#define ZEN_BRIGHT_BLUE(str) "\033[94m" str "\033[0m"
+#define ZEN_BRIGHT_MAGENTA(str) "\033[95m" str "\033[0m"
+#define ZEN_BRIGHT_CYAN(str) "\033[96m" str "\033[0m"
+#define ZEN_BRIGHT_WHITE(str) "\033[97m" str "\033[0m"
+
+#define ZEN_BOLD(str) "\033[1m" str "\033[0m"
+#define ZEN_UNDERLINE(str) "\033[4m" str "\033[0m"
+#define ZEN_DIM(str) "\033[2m" str "\033[0m"
+#define ZEN_ITALIC(str) "\033[3m" str "\033[0m"
+#define ZEN_STRIKETHROUGH(str) "\033[9m" str "\033[0m"
+#define ZEN_INVERSE(str) "\033[7m" str "\033[0m"
+
+////////////////////////////////////////////////////////////////////////
+
#if ZEN_BUILD_DEBUG
# define ZEN_CHECK_FORMAT_STRING(fmtstr, ...) \
while (false) \
@@ -103,31 +131,31 @@ using zen::Log;
}
#endif
-#define ZEN_LOG_WITH_LOCATION(InLogger, InLevel, fmtstr, ...) \
- do \
- { \
- using namespace std::literals; \
- static constinit ZEN_LOG_SECTION(".zlog$l") \
- zen::logging::LogPoint LogPoint{zen::logging::SourceLocation{__FILE__, __LINE__}, InLevel, std::string_view(fmtstr)}; \
- zen::LoggerRef Logger = InLogger; \
- ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
- if (Logger.ShouldLog(InLevel)) \
- { \
- zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
- } \
+#define ZEN_LOG_WITH_LOCATION(InLogger, InLevel, fmtstr, ...) \
+ do \
+ { \
+ using namespace std::literals; \
+ static constinit ZEN_LOG_SECTION(".zlog$l") \
+ zen::logging::LogPoint LogPoint{__FILE__, __LINE__, InLevel, std::string_view(fmtstr)}; \
+ zen::LoggerRef Logger = InLogger; \
+ ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
+ if (Logger.ShouldLog(InLevel)) \
+ { \
+ zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
+ } \
} while (false);
-#define ZEN_LOG(InLogger, InLevel, fmtstr, ...) \
- do \
- { \
- using namespace std::literals; \
- static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \
- zen::LoggerRef Logger = InLogger; \
- ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
- if (Logger.ShouldLog(InLevel)) \
- { \
- zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
- } \
+#define ZEN_LOG(InLogger, InLevel, fmtstr, ...) \
+ do \
+ { \
+ using namespace std::literals; \
+ static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{0, 0, InLevel, std::string_view(fmtstr)}; \
+ zen::LoggerRef Logger = InLogger; \
+ ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
+ if (Logger.ShouldLog(InLevel)) \
+ { \
+ zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
+ } \
} while (false);
#define ZEN_DEFINE_LOG_CATEGORY_STATIC(Category, Name) \
@@ -147,13 +175,18 @@ using zen::Log;
#define ZEN_ERROR(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Err, fmtstr, ##__VA_ARGS__)
#define ZEN_CRITICAL(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Critical, fmtstr, ##__VA_ARGS__)
-#define ZEN_CONSOLE_LOG(InLevel, fmtstr, ...) \
- do \
- { \
- using namespace std::literals; \
- static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \
- ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
- zen::logging::EmitConsoleLogMessage(LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
+// Routes ZEN_INFO / ZEN_WARN / ZEN_DEBUG etc. in the enclosing scope through the given logger expression
+// (a LoggerRef or something convertible, e.g. a member or a context field) instead of the namespace default.
+// Expand at block scope; the resulting local `Log` shadows `zen::Log()` for the rest of the block.
+#define ZEN_SCOPED_LOG(Expr) auto Log = [&]() { return (Expr); }
+
+#define ZEN_CONSOLE_LOG(InLevel, fmtstr, ...) \
+ do \
+ { \
+ using namespace std::literals; \
+ static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{0, 0, InLevel, std::string_view(fmtstr)}; \
+ ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
+ zen::logging::EmitConsoleLogMessage(LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
} while (false)
#define ZEN_CONSOLE(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__)
diff --git a/src/zencore/include/zencore/logging/broadcastsink.h b/src/zencore/include/zencore/logging/broadcastsink.h
index c2709d87c..474662888 100644
--- a/src/zencore/include/zencore/logging/broadcastsink.h
+++ b/src/zencore/include/zencore/logging/broadcastsink.h
@@ -17,7 +17,7 @@ namespace zen::logging {
/// sink is immediately visible to all of them. This is the recommended way
/// to manage "default" sinks that should be active on most loggers.
///
-/// Each child sink owns its own Formatter — BroadcastSink::SetFormatter() is
+/// Each child sink owns its own Formatter - BroadcastSink::SetFormatter() is
/// intentionally a no-op so that per-sink formatting is not accidentally
/// overwritten by registry-wide formatter changes.
class BroadcastSink : public Sink
@@ -63,7 +63,7 @@ public:
}
}
- /// No-op — child sinks manage their own formatters.
+ /// No-op - child sinks manage their own formatters.
void SetFormatter(std::unique_ptr<Formatter> /*InFormatter*/) override {}
void AddSink(SinkPtr InSink)
diff --git a/src/zencore/include/zencore/logging/helpers.h b/src/zencore/include/zencore/logging/helpers.h
index 765aa59e3..1092e7095 100644
--- a/src/zencore/include/zencore/logging/helpers.h
+++ b/src/zencore/include/zencore/logging/helpers.h
@@ -116,7 +116,7 @@ ShortFilename(const char* Path)
inline std::string_view
LevelToShortString(LogLevel Level)
{
- return ToStringView(Level);
+ return ShortToStringView(Level);
}
inline std::string_view
diff --git a/src/zencore/include/zencore/logging/logmsg.h b/src/zencore/include/zencore/logging/logmsg.h
index 4a777c71e..644af2730 100644
--- a/src/zencore/include/zencore/logging/logmsg.h
+++ b/src/zencore/include/zencore/logging/logmsg.h
@@ -7,49 +7,50 @@
#include <chrono>
#include <string_view>
+namespace zen {
+int GetCurrentThreadId();
+}
+
namespace zen::logging {
using LogClock = std::chrono::system_clock;
+/**
+ * This represents a single log event, with all the data needed to format and
+ * emit the final log message.
+ *
+ * LogMessage is what gets passed to Sinks, and it's what contains all the
+ * contextual information about the log event.
+ */
+
struct LogMessage
{
- LogMessage() = default;
-
LogMessage(const LogPoint& InPoint, std::string_view InLoggerName, std::string_view InPayload)
: m_LoggerName(InLoggerName)
- , m_Level(InPoint.Level)
, m_Time(LogClock::now())
- , m_Source(InPoint.Location)
+ , m_ThreadId(zen::GetCurrentThreadId())
, m_Payload(InPayload)
, m_Point(&InPoint)
{
}
- std::string_view GetPayload() const { return m_Payload; }
- int GetThreadId() const { return m_ThreadId; }
- LogClock::time_point GetTime() const { return m_Time; }
- LogLevel GetLevel() const { return m_Level; }
- std::string_view GetLoggerName() const { return m_LoggerName; }
- const SourceLocation& GetSource() const { return m_Source; }
- const LogPoint& GetLogPoint() const { return *m_Point; }
+ [[nodiscard]] std::string_view GetPayload() const { return m_Payload; }
+ [[nodiscard]] int GetThreadId() const { return m_ThreadId; }
+ [[nodiscard]] LogClock::time_point GetTime() const { return m_Time; }
+ [[nodiscard]] LogLevel GetLevel() const { return m_Point->Level; }
+ [[nodiscard]] std::string_view GetLoggerName() const { return m_LoggerName; }
+ [[nodiscard]] SourceLocation GetSource() const { return m_Point->Location(); }
+ [[nodiscard]] const LogPoint& GetLogPoint() const { return *m_Point; }
void SetThreadId(int InThreadId) { m_ThreadId = InThreadId; }
- void SetPayload(std::string_view InPayload) { m_Payload = InPayload; }
- void SetLoggerName(std::string_view InName) { m_LoggerName = InName; }
- void SetLevel(LogLevel InLevel) { m_Level = InLevel; }
void SetTime(LogClock::time_point InTime) { m_Time = InTime; }
- void SetSource(const SourceLocation& InSource) { m_Source = InSource; }
private:
- static constexpr LogPoint s_DefaultPoint{{}, Off, {}};
-
- std::string_view m_LoggerName;
- LogLevel m_Level = Off;
- std::chrono::system_clock::time_point m_Time;
- SourceLocation m_Source;
- std::string_view m_Payload;
- const LogPoint* m_Point = &s_DefaultPoint;
- int m_ThreadId = 0;
+ std::string_view m_LoggerName;
+ LogClock::time_point m_Time;
+ int m_ThreadId;
+ std::string_view m_Payload;
+ const LogPoint* m_Point;
};
} // namespace zen::logging
diff --git a/src/zencore/include/zencore/mpscqueue.h b/src/zencore/include/zencore/mpscqueue.h
index d97c433fd..38a0bc14f 100644
--- a/src/zencore/include/zencore/mpscqueue.h
+++ b/src/zencore/include/zencore/mpscqueue.h
@@ -11,7 +11,7 @@
using std::hardware_constructive_interference_size;
using std::hardware_destructive_interference_size;
#else
-// 64 bytes on x86-64 │ L1_CACHE_BYTES │ L1_CACHE_SHIFT │ __cacheline_aligned │ ...
+// 64 bytes on x86-64 | L1_CACHE_BYTES | L1_CACHE_SHIFT | __cacheline_aligned | ...
constexpr std::size_t hardware_constructive_interference_size = 64;
constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
diff --git a/src/zencore/include/zencore/process.h b/src/zencore/include/zencore/process.h
index 5ae7fad68..fd24a6d7d 100644
--- a/src/zencore/include/zencore/process.h
+++ b/src/zencore/include/zencore/process.h
@@ -34,7 +34,7 @@ public:
/// Throws std::system_error on failure.
explicit ProcessHandle(int Pid);
- /// Construct from an existing native process handle. Takes ownership —
+ /// Construct from an existing native process handle. Takes ownership -
/// the caller must not close the handle afterwards. Windows only.
#if ZEN_PLATFORM_WINDOWS
explicit ProcessHandle(void* NativeHandle);
@@ -56,7 +56,7 @@ public:
/// Same as Initialize(int) but reports errors via @p OutEc instead of throwing.
void Initialize(int Pid, std::error_code& OutEc);
- /// Initialize from an existing native process handle. Takes ownership —
+ /// Initialize from an existing native process handle. Takes ownership -
/// the caller must not close the handle afterwards. Windows only.
#if ZEN_PLATFORM_WINDOWS
void Initialize(void* ProcessHandle);
@@ -174,13 +174,18 @@ struct CreateProcOptions
// allocated and no conhost.exe is spawned. Stdout/stderr still work when redirected
// via pipes. Prefer this for headless worker processes.
Flag_NoConsole = 1 << 3,
- // Create the child in a new process group (CREATE_NEW_PROCESS_GROUP on Windows).
- // Allows sending CTRL_BREAK_EVENT to the child group without affecting the parent.
- Flag_Windows_NewProcessGroup = 1 << 4,
+ // Spawn the child as a new process group leader (its pgid = its own pid).
+ // On Windows: CREATE_NEW_PROCESS_GROUP, enables CTRL_BREAK_EVENT targeting.
+ // On POSIX: child calls setpgid(0,0) / posix_spawn with POSIX_SPAWN_SETPGROUP+pgid=0.
+ // Mutually exclusive with ProcessGroupId > 0.
+ Flag_NewProcessGroup = 1 << 4,
// Allocate a hidden console for the child (CREATE_NO_WINDOW on Windows). Unlike
// Flag_NoConsole the child still gets a console (and a conhost.exe) but no visible
// window. Use this when the child needs a console for stdio but should not show a window.
Flag_NoWindow = 1 << 5,
+ // Launch the child at below-normal scheduling priority.
+ // On Windows: BELOW_NORMAL_PRIORITY_CLASS. On POSIX: nice(5).
+ Flag_BelowNormalPriority = 1 << 6,
};
const std::filesystem::path* WorkingDirectory = nullptr;
@@ -190,16 +195,16 @@ struct CreateProcOptions
StdoutPipeHandles* StderrPipe = nullptr; // Optional separate pipe for stderr. When null, stderr shares StdoutPipe.
/// Additional environment variables for the child process. These are merged
- /// with the parent's environment — existing variables are inherited, and
+ /// with the parent's environment - existing variables are inherited, and
/// entries here override or add to them.
std::vector<std::pair<std::string, std::string>> Environment;
#if ZEN_PLATFORM_WINDOWS
JobObject* AssignToJob = nullptr; // When set, the process is created suspended, assigned to the job, then resumed
#else
- /// POSIX process group id. When > 0, the child is placed into this process
- /// group via setpgid() before exec. Use the pid of the first child as the
- /// pgid to create a group, then pass the same pgid for subsequent children.
+ /// When > 0, child joins this existing process group. Mutually exclusive with
+ /// Flag_NewProcessGroup; use that flag on the first spawn to create the group,
+ /// then pass the resulting pid here for subsequent spawns to join it.
int ProcessGroupId = 0;
#endif
};
diff --git a/src/zencore/include/zencore/sharedbuffer.h b/src/zencore/include/zencore/sharedbuffer.h
index 3d4c19282..3183c7c0c 100644
--- a/src/zencore/include/zencore/sharedbuffer.h
+++ b/src/zencore/include/zencore/sharedbuffer.h
@@ -65,7 +65,7 @@ public:
private:
// This may be null, for a default constructed UniqueBuffer only
- RefPtr<IoBufferCore> m_Buffer;
+ Ref<IoBufferCore> m_Buffer;
friend class SharedBuffer;
};
@@ -81,7 +81,7 @@ public:
inline explicit SharedBuffer(IoBufferCore* Owner) : m_Buffer(Owner) {}
explicit SharedBuffer(IoBuffer&& Buffer) : m_Buffer(std::move(Buffer.m_Core)) {}
explicit SharedBuffer(const IoBuffer& Buffer) : m_Buffer(Buffer.m_Core) {}
- explicit SharedBuffer(RefPtr<IoBufferCore>&& Owner) : m_Buffer(std::move(Owner)) {}
+ explicit SharedBuffer(Ref<IoBufferCore>&& Owner) : m_Buffer(std::move(Owner)) {}
[[nodiscard]] const void* GetData() const
{
@@ -143,7 +143,7 @@ public:
/** Returns true if this points to a buffer owner. */
[[nodiscard]] inline explicit operator bool() const { return !IsNull(); }
- [[nodiscard]] inline IoBuffer AsIoBuffer() const { return IoBuffer(m_Buffer); }
+ [[nodiscard]] inline IoBuffer AsIoBuffer() const { return IoBuffer(m_Buffer.Get()); }
SharedBuffer& operator=(UniqueBuffer&& Rhs)
{
@@ -171,7 +171,7 @@ public:
[[nodiscard]] static SharedBuffer Clone(MemoryView View);
private:
- RefPtr<IoBufferCore> m_Buffer;
+ Ref<IoBufferCore> m_Buffer;
};
void sharedbuffer_forcelink();
diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h
index 60293a313..b4926070c 100644
--- a/src/zencore/include/zencore/string.h
+++ b/src/zencore/include/zencore/string.h
@@ -402,6 +402,12 @@ public:
inline std::string_view ToView() const { return std::string_view(m_Base, m_CurPos - m_Base); }
inline std::string ToString() const { return std::string{Data(), Size()}; }
+ /// Append a zero-padded decimal integer. MinWidth is the minimum number of digits (zero-padded on the left).
+ void AppendPaddedInt(int64_t Value, int MinWidth);
+
+ /// Append a single character repeated Count times.
+ void AppendFill(char C, size_t Count);
+
inline void AppendCodepoint(uint32_t cp)
{
if (cp < 0x80) // one octet
@@ -435,6 +441,24 @@ public:
}
};
+/// Output iterator adapter for StringBuilderBase, enabling direct use with fmt::format_to / fmt::format_to_n.
+class StringBuilderAppender
+{
+ StringBuilderBase* m_Builder;
+
+public:
+ explicit StringBuilderAppender(StringBuilderBase& Builder) : m_Builder(&Builder) {}
+
+ StringBuilderAppender& operator=(char C)
+ {
+ m_Builder->Append(C);
+ return *this;
+ }
+ StringBuilderAppender& operator*() { return *this; }
+ StringBuilderAppender& operator++() { return *this; }
+ StringBuilderAppender operator++(int) { return *this; }
+};
+
template<size_t N>
class StringBuilder : public StringBuilderBase
{
@@ -609,6 +633,17 @@ ParseHexBytes(std::string_view InputString, uint8_t* OutPtr)
return ParseHexBytes(InputString.data(), InputString.size(), OutPtr);
}
+/** Parse hex string into a byte buffer, validating that the hex string is exactly ExpectedByteCount * 2 characters. */
+inline bool
+ParseHexBytes(std::string_view InputString, uint8_t* OutPtr, size_t ExpectedByteCount)
+{
+ if (InputString.size() != ExpectedByteCount * 2)
+ {
+ return false;
+ }
+ return ParseHexBytes(InputString.data(), InputString.size(), OutPtr);
+}
+
inline void
ToHexBytes(const uint8_t* InputData, size_t ByteCount, char* OutString)
{
@@ -722,6 +757,32 @@ struct NiceNum : public NiceBase
inline NiceNum(uint64_t Num) { NiceNumToBuffer(Num, m_Buffer); }
};
+size_t ThousandsToBuffer(uint64_t Num, std::span<char> Buffer);
+
+/// Integer formatted with comma thousands separators (e.g. "1,234,567")
+struct ThousandsNum
+{
+ inline ThousandsNum(UnsignedIntegral auto Number) { ThousandsToBuffer(uint64_t(Number), m_Buffer); }
+ inline ThousandsNum(SignedIntegral auto Number)
+ {
+ if (Number < 0)
+ {
+ m_Buffer[0] = '-';
+ ThousandsToBuffer(uint64_t(-Number), std::span<char>(m_Buffer + 1, sizeof(m_Buffer) - 1));
+ }
+ else
+ {
+ ThousandsToBuffer(uint64_t(Number), m_Buffer);
+ }
+ }
+
+ inline const char* c_str() const { return m_Buffer; }
+ inline operator std::string_view() const { return std::string_view(m_Buffer); }
+
+private:
+ char m_Buffer[28]; // max uint64: "18,446,744,073,709,551,615" (26) + NUL + sign
+};
+
struct NiceBytes : public NiceBase
{
inline NiceBytes(uint64_t Num) { NiceBytesToBuffer(Num, m_Buffer); }
diff --git a/src/zencore/include/zencore/system.h b/src/zencore/include/zencore/system.h
index 52dafc18b..efc9bb6d2 100644
--- a/src/zencore/include/zencore/system.h
+++ b/src/zencore/include/zencore/system.h
@@ -46,6 +46,11 @@ struct ExtendedSystemMetrics : SystemMetrics
SystemMetrics GetSystemMetrics();
+/// Lightweight query that only refreshes fields that change at runtime
+/// (available memory, uptime). Topology fields (CPU/core counts, total memory)
+/// are left at their default values and must be filled from a cached snapshot.
+void RefreshDynamicSystemMetrics(SystemMetrics& InOutMetrics);
+
void SetCpuCountForReporting(int FakeCpuCount);
SystemMetrics GetSystemMetricsForReporting();
diff --git a/src/zencore/include/zencore/testutils.h b/src/zencore/include/zencore/testutils.h
index 2a789d18f..68461deb2 100644
--- a/src/zencore/include/zencore/testutils.h
+++ b/src/zencore/include/zencore/testutils.h
@@ -62,24 +62,24 @@ struct TrueType
namespace utf8test {
// 2-byte UTF-8 (Latin extended)
- static constexpr const char kLatin[] = u8"café_résumé";
- static constexpr const wchar_t kLatinW[] = L"café_résumé";
+ static constexpr const char kLatin[] = u8"caf\xC3\xA9_r\xC3\xA9sum\xC3\xA9";
+ static constexpr const wchar_t kLatinW[] = L"caf\u00E9_r\u00E9sum\u00E9";
// 2-byte UTF-8 (Cyrillic)
- static constexpr const char kCyrillic[] = u8"данные";
- static constexpr const wchar_t kCyrillicW[] = L"данные";
+ static constexpr const char kCyrillic[] = u8"\xD0\xB4\xD0\xB0\xD0\xBD\xD0\xBD\xD1\x8B\xD0\xB5";
+ static constexpr const wchar_t kCyrillicW[] = L"\u0434\u0430\u043D\u043D\u044B\u0435";
// 3-byte UTF-8 (CJK)
- static constexpr const char kCJK[] = u8"日本語";
- static constexpr const wchar_t kCJKW[] = L"日本語";
+ static constexpr const char kCJK[] = u8"\xE6\x97\xA5\xE6\x9C\xAC\xE8\xAA\x9E";
+ static constexpr const wchar_t kCJKW[] = L"\u65E5\u672C\u8A9E";
// Mixed scripts
- static constexpr const char kMixed[] = u8"zen_éд日";
- static constexpr const wchar_t kMixedW[] = L"zen_éд日";
+ static constexpr const char kMixed[] = u8"zen_\xC3\xA9\xD0\xB4\xE6\x97\xA5";
+ static constexpr const wchar_t kMixedW[] = L"zen_\u00E9\u0434\u65E5";
- // 4-byte UTF-8 (supplementary plane) — string tests only, NOT filesystem
- static constexpr const char kEmoji[] = u8"📦";
- static constexpr const wchar_t kEmojiW[] = L"📦";
+ // 4-byte UTF-8 (supplementary plane) - string tests only, NOT filesystem
+ static constexpr const char kEmoji[] = u8"\xF0\x9F\x93\xA6";
+ static constexpr const wchar_t kEmojiW[] = L"\U0001F4E6";
// BMP-only test strings suitable for filesystem use
static constexpr const char* kFilenameSafe[] = {kLatin, kCyrillic, kCJK, kMixed};
diff --git a/src/zencore/include/zencore/thread.h b/src/zencore/include/zencore/thread.h
index 56ce5904b..0f7733df5 100644
--- a/src/zencore/include/zencore/thread.h
+++ b/src/zencore/include/zencore/thread.h
@@ -14,7 +14,7 @@
namespace zen {
-void SetCurrentThreadName(std::string_view ThreadName);
+void SetCurrentThreadName(std::string_view ThreadName, int32_t SortHint = 0);
/**
* Reader-writer lock
diff --git a/src/zencore/include/zencore/zencore.h b/src/zencore/include/zencore/zencore.h
index a31950b0b..57c7e20fa 100644
--- a/src/zencore/include/zencore/zencore.h
+++ b/src/zencore/include/zencore/zencore.h
@@ -94,7 +94,7 @@ protected:
// With no extra args: ZEN_ASSERT_MSG_("expr") -> "expr"
// With a message arg: ZEN_ASSERT_MSG_("expr", "msg") -> "expr" ": " "msg"
// With fmt-style args: ZEN_ASSERT_MSG_("expr", "msg", args...) -> "expr" ": " "msg"
-// The extra fmt args are silently discarded here — use ZEN_ASSERT_FORMAT for those.
+// The extra fmt args are silently discarded here - use ZEN_ASSERT_FORMAT for those.
#define ZEN_ASSERT_MSG_SELECT_(_1, _2, N, ...) N
#define ZEN_ASSERT_MSG_1_(expr) expr
#define ZEN_ASSERT_MSG_2_(expr, msg, ...) expr ": " msg
diff --git a/src/zencore/iobuffer.cpp b/src/zencore/iobuffer.cpp
index c47c54981..529afe341 100644
--- a/src/zencore/iobuffer.cpp
+++ b/src/zencore/iobuffer.cpp
@@ -107,7 +107,7 @@ IoBufferCore::~IoBufferCore()
}
void
-IoBufferCore::DeleteThis() const
+IoBufferCore::DeleteThis() const noexcept
{
// We do this just to avoid paying for the cost of a vtable
if (const IoBufferExtendedCore* _ = ExtendedCore())
@@ -210,7 +210,12 @@ IoBufferExtendedCore::~IoBufferExtendedCore()
// Mark file for deletion when final handle is closed
FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE};
- SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi);
+ if (!SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi))
+ {
+ ZEN_WARN("SetFileInformationByHandle(DeleteOnClose) failed for file handle {}, reason '{}'",
+ m_FileHandle,
+ GetLastErrorAsString());
+ }
#else
std::error_code Ec;
std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle, Ec);
@@ -447,7 +452,7 @@ GetNullBufferCore()
return Core;
}
-RefPtr<IoBufferCore> IoBuffer::NullBufferCore(GetNullBufferCore());
+Ref<IoBufferCore> IoBuffer::NullBufferCore(GetNullBufferCore());
IoBuffer::IoBuffer(size_t InSize) : m_Core(new IoBufferCore(InSize))
{
@@ -475,7 +480,7 @@ IoBuffer::IoBuffer(const IoBuffer& OuterBuffer, size_t Offset, size_t Size)
}
else
{
- m_Core = new IoBufferCore(OuterBuffer.m_Core, reinterpret_cast<const uint8_t*>(OuterBuffer.Data()) + Offset, Size);
+ m_Core = new IoBufferCore(OuterBuffer.m_Core.Get(), reinterpret_cast<const uint8_t*>(OuterBuffer.Data()) + Offset, Size);
}
}
diff --git a/src/zencore/jobqueue.cpp b/src/zencore/jobqueue.cpp
index 3e58fb97d..a5a82717d 100644
--- a/src/zencore/jobqueue.cpp
+++ b/src/zencore/jobqueue.cpp
@@ -93,7 +93,7 @@ public:
{
NewJobId = IdGenerator.fetch_add(1);
}
- RefPtr<Job> NewJob(new Job());
+ Ref<Job> NewJob(new Job());
NewJob->Queue = this;
NewJob->Name = Name;
NewJob->Callback = std::move(JobFunc);
@@ -124,7 +124,7 @@ public:
QueueLock.WithExclusiveLock([&]() {
if (auto It = std::find_if(QueuedJobs.begin(),
QueuedJobs.end(),
- [NewJobId](const RefPtr<Job>& Job) { return Job->Id.Id == NewJobId; });
+ [NewJobId](const Ref<Job>& Job) { return Job->Id.Id == NewJobId; });
It != QueuedJobs.end())
{
QueuedJobs.erase(It);
@@ -156,7 +156,7 @@ public:
Result = true;
return;
}
- if (auto It = std::find_if(QueuedJobs.begin(), QueuedJobs.end(), [&Id](const RefPtr<Job>& Job) { return Job->Id.Id == Id.Id; });
+ if (auto It = std::find_if(QueuedJobs.begin(), QueuedJobs.end(), [&Id](const Ref<Job>& Job) { return Job->Id.Id == Id.Id; });
It != QueuedJobs.end())
{
ZEN_DEBUG("Cancelling queued background job {}:'{}'", (*It)->Id.Id, (*It)->Name);
@@ -301,7 +301,7 @@ public:
AbortedJobs.erase(It);
return;
}
- if (auto It = std::find_if(QueuedJobs.begin(), QueuedJobs.end(), [&Id](const RefPtr<Job>& Job) { return Job->Id.Id == Id.Id; });
+ if (auto It = std::find_if(QueuedJobs.begin(), QueuedJobs.end(), [&Id](const Ref<Job>& Job) { return Job->Id.Id == Id.Id; });
It != QueuedJobs.end())
{
Result = Convert(JobStatus::Queued, *(*It));
@@ -340,20 +340,20 @@ public:
std::atomic_uint64_t IdGenerator = 1;
- std::atomic_bool InitializedFlag = false;
- RwLock QueueLock;
- std::deque<RefPtr<Job>> QueuedJobs;
- std::unordered_map<uint64_t, RefPtr<Job>> RunningJobs;
- std::unordered_map<uint64_t, RefPtr<Job>> CompletedJobs;
- std::unordered_map<uint64_t, RefPtr<Job>> AbortedJobs;
+ std::atomic_bool InitializedFlag = false;
+ RwLock QueueLock;
+ std::deque<Ref<Job>> QueuedJobs;
+ std::unordered_map<uint64_t, Ref<Job>> RunningJobs;
+ std::unordered_map<uint64_t, Ref<Job>> CompletedJobs;
+ std::unordered_map<uint64_t, Ref<Job>> AbortedJobs;
WorkerThreadPool WorkerPool;
Latch WorkerCounter;
void Worker()
{
- int CurrentThreadId = GetCurrentThreadId();
- RefPtr<Job> CurrentJob;
+ int CurrentThreadId = GetCurrentThreadId();
+ Ref<Job> CurrentJob;
QueueLock.WithExclusiveLock([&]() {
if (!QueuedJobs.empty())
{
diff --git a/src/zencore/logging.cpp b/src/zencore/logging.cpp
index 5ada0cac7..aa95db950 100644
--- a/src/zencore/logging.cpp
+++ b/src/zencore/logging.cpp
@@ -26,7 +26,7 @@ namespace {
// Bootstrap logger: a minimal stdout logger that exists for the entire lifetime
// of the process. TheDefaultLogger points here before InitializeLogging() runs
// (and is restored here after ShutdownLogging()) so that log macros always have
-// a usable target — no null checks or lazy init required on the common path.
+// a usable target - no null checks or lazy init required on the common path.
zen::Ref<zen::logging::Logger> s_BootstrapLogger = [] {
zen::logging::SinkPtr Sink(new zen::logging::AnsiColorStdoutSink());
return zen::Ref<zen::logging::Logger>(new zen::logging::Logger("", Sink));
@@ -112,6 +112,14 @@ constinit std::string_view LevelNames[] = {std::string_view("trace", 5),
std::string_view("critical", 8),
std::string_view("off", 3)};
+constinit std::string_view ShortNames[] = {std::string_view("trc", 3),
+ std::string_view("dbg", 3),
+ std::string_view("inf", 3),
+ std::string_view("wrn", 3),
+ std::string_view("err", 3),
+ std::string_view("crt", 3),
+ std::string_view("off", 3)};
+
LogLevel
ParseLogLevelString(std::string_view Name)
{
@@ -139,12 +147,27 @@ ParseLogLevelString(std::string_view Name)
std::string_view
ToStringView(LogLevel Level)
{
+ using namespace std::literals;
+
if (int(Level) < LogLevelCount)
{
return LevelNames[int(Level)];
}
- return "None";
+ return "None"sv;
+}
+
+std::string_view
+ShortToStringView(LogLevel Level)
+{
+ using namespace std::literals;
+
+ if (int(Level) < LogLevelCount)
+ {
+ return ShortNames[int(Level)];
+ }
+
+ return "None"sv;
}
} // namespace zen::logging
@@ -476,6 +499,10 @@ LoggerRef::LoggerRef(logging::Logger& InLogger) : m_Logger(static_cast<logging::
{
}
+LoggerRef::LoggerRef(std::string_view LogCategory) : m_Logger(zen::logging::Get(LogCategory).m_Logger)
+{
+}
+
void
LoggerRef::Flush()
{
diff --git a/src/zencore/logging/ansicolorsink.cpp b/src/zencore/logging/ansicolorsink.cpp
index 03aae068a..fb127bede 100644
--- a/src/zencore/logging/ansicolorsink.cpp
+++ b/src/zencore/logging/ansicolorsink.cpp
@@ -5,6 +5,7 @@
#include <zencore/logging/messageonlyformatter.h>
#include <zencore/thread.h>
+#include <zencore/timer.h>
#include <cstdio>
#include <cstdlib>
@@ -22,48 +23,37 @@
namespace zen::logging {
-// Default formatter replicating spdlog's %+ pattern:
-// [YYYY-MM-DD HH:MM:SS.mmm] [logger_name] [level] message\n
+// Default formatter for console output:
+// [HH:MM:SS.mmm] [logger_name] [level] message\n
+// Timestamps show elapsed time since process launch.
class DefaultConsoleFormatter : public Formatter
{
public:
+ DefaultConsoleFormatter() : m_Epoch(std::chrono::system_clock::now() - std::chrono::milliseconds(GetTimeSinceProcessStart())) {}
+
void Format(const LogMessage& Msg, MemoryBuffer& Dest) override
{
- // timestamp
- auto Secs = std::chrono::duration_cast<std::chrono::seconds>(Msg.GetTime().time_since_epoch());
- if (Secs != m_LastLogSecs)
- {
- m_LastLogSecs = Secs;
- m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime()));
- }
+ // Elapsed time since process launch
+ auto Elapsed = Msg.GetTime() - m_Epoch;
+ auto TotalSecs = std::chrono::duration_cast<std::chrono::seconds>(Elapsed);
+ int Count = static_cast<int>(TotalSecs.count());
+ int LogSecs = Count % 60;
+ Count /= 60;
+ int LogMins = Count % 60;
+ int LogHours = Count / 60;
Dest.push_back('[');
- helpers::AppendInt(m_CachedLocalTm.tm_year + 1900, Dest);
- Dest.push_back('-');
- helpers::Pad2(m_CachedLocalTm.tm_mon + 1, Dest);
- Dest.push_back('-');
- helpers::Pad2(m_CachedLocalTm.tm_mday, Dest);
- Dest.push_back(' ');
- helpers::Pad2(m_CachedLocalTm.tm_hour, Dest);
+ helpers::Pad2(LogHours, Dest);
Dest.push_back(':');
- helpers::Pad2(m_CachedLocalTm.tm_min, Dest);
+ helpers::Pad2(LogMins, Dest);
Dest.push_back(':');
- helpers::Pad2(m_CachedLocalTm.tm_sec, Dest);
+ helpers::Pad2(LogSecs, Dest);
Dest.push_back('.');
auto Millis = helpers::TimeFraction<std::chrono::milliseconds>(Msg.GetTime());
helpers::Pad3(static_cast<uint32_t>(Millis.count()), Dest);
Dest.push_back(']');
Dest.push_back(' ');
- // logger name
- if (Msg.GetLoggerName().size() > 0)
- {
- Dest.push_back('[');
- helpers::AppendStringView(Msg.GetLoggerName(), Dest);
- Dest.push_back(']');
- Dest.push_back(' ');
- }
-
// level
Dest.push_back('[');
if (IsColorEnabled())
@@ -78,6 +68,25 @@ public:
Dest.push_back(']');
Dest.push_back(' ');
+ using namespace std::literals;
+
+ // logger name
+ if (Msg.GetLoggerName().size() > 0)
+ {
+ if (IsColorEnabled())
+ {
+ Dest.append("\033[97m"sv);
+ }
+ Dest.push_back('[');
+ helpers::AppendStringView(Msg.GetLoggerName(), Dest);
+ Dest.push_back(']');
+ if (IsColorEnabled())
+ {
+ Dest.append("\033[0m"sv);
+ }
+ Dest.push_back(' ');
+ }
+
// message (align continuation lines with the first line)
size_t AnsiBytes = IsColorEnabled() ? (helpers::AnsiColorForLevel(Msg.GetLevel()).size() + helpers::kAnsiReset.size()) : 0;
size_t LinePrefixCount = Dest.size() - AnsiBytes;
@@ -128,8 +137,7 @@ public:
}
private:
- std::chrono::seconds m_LastLogSecs{0};
- std::tm m_CachedLocalTm{};
+ LogClock::time_point m_Epoch;
};
bool
@@ -197,7 +205,7 @@ IsColorTerminal()
// Windows console supports ANSI color by default in modern versions
return true;
#else
- // Unknown terminal — be conservative
+ // Unknown terminal - be conservative
return false;
#endif
}
diff --git a/src/zencore/logging/registry.cpp b/src/zencore/logging/registry.cpp
index 383a5d8ba..0f552aced 100644
--- a/src/zencore/logging/registry.cpp
+++ b/src/zencore/logging/registry.cpp
@@ -137,7 +137,7 @@ struct Registry::Impl
{
if (Pattern.find_first_of("*?") == std::string::npos)
{
- // Exact match — fast path via map lookup.
+ // Exact match - fast path via map lookup.
auto It = m_Loggers.find(Pattern);
if (It != m_Loggers.end())
{
@@ -146,7 +146,7 @@ struct Registry::Impl
}
else
{
- // Wildcard pattern — iterate all loggers.
+ // Wildcard pattern - iterate all loggers.
for (auto& [Name, CurLogger] : m_Loggers)
{
if (MatchLoggerPattern(Pattern, Name))
diff --git a/src/zencore/memory/memory.cpp b/src/zencore/memory/memory.cpp
index 9e19c5db7..8dbb04e64 100644
--- a/src/zencore/memory/memory.cpp
+++ b/src/zencore/memory/memory.cpp
@@ -44,10 +44,10 @@ InitGMalloc()
// when using sanitizers, we should use the default/ansi allocator
#if ZEN_OVERRIDE_NEW_DELETE
-# if ZEN_RPMALLOC_ENABLED
+# if ZEN_MIMALLOC_ENABLED
if (Malloc == MallocImpl::None)
{
- Malloc = MallocImpl::Rpmalloc;
+ Malloc = MallocImpl::Mimalloc;
}
# endif
#endif
diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp
index e7baa3f8e..66062df4d 100644
--- a/src/zencore/process.cpp
+++ b/src/zencore/process.cpp
@@ -28,6 +28,7 @@ ZEN_THIRD_PARTY_INCLUDES_START
# include <pthread.h>
# include <signal.h>
# include <sys/file.h>
+# include <sys/resource.h>
# include <sys/sem.h>
# include <sys/stat.h>
# include <sys/syscall.h>
@@ -37,7 +38,9 @@ ZEN_THIRD_PARTY_INCLUDES_START
#endif
#if ZEN_PLATFORM_MAC
+# include <crt_externs.h>
# include <libproc.h>
+# include <spawn.h>
# include <sys/types.h>
# include <sys/sysctl.h>
#endif
@@ -135,8 +138,68 @@ IsZombieProcess(int pid, std::error_code& OutEc)
}
return false;
}
+
+static char**
+GetEnviron()
+{
+ return *_NSGetEnviron();
+}
#endif // ZEN_PLATFORM_MAC
+#if ZEN_PLATFORM_LINUX
+static char**
+GetEnviron()
+{
+ return environ;
+}
+#endif // ZEN_PLATFORM_LINUX
+
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+// Holds a null-terminated envp array built by merging the current process environment with
+// a set of overrides. When Overrides is empty, Data points directly to environ (no allocation).
+// Must outlive any posix_spawn / execve call that receives Data.
+struct EnvpHolder
+{
+ char** Data = GetEnviron();
+
+ explicit EnvpHolder(const std::vector<std::pair<std::string, std::string>>& Overrides)
+ {
+ if (Overrides.empty())
+ {
+ return;
+ }
+ std::map<std::string, std::string> EnvMap;
+ for (char** E = GetEnviron(); *E; ++E)
+ {
+ std::string_view Entry(*E);
+ const size_t EqPos = Entry.find('=');
+ if (EqPos != std::string_view::npos)
+ {
+ EnvMap[std::string(Entry.substr(0, EqPos))] = std::string(Entry.substr(EqPos + 1));
+ }
+ }
+ for (const auto& [Key, Value] : Overrides)
+ {
+ EnvMap[Key] = Value;
+ }
+ for (const auto& [Key, Value] : EnvMap)
+ {
+ m_Strings.push_back(Key + "=" + Value);
+ }
+ for (std::string& S : m_Strings)
+ {
+ m_Ptrs.push_back(S.data());
+ }
+ m_Ptrs.push_back(nullptr);
+ Data = m_Ptrs.data();
+ }
+
+private:
+ std::vector<std::string> m_Strings;
+ std::vector<char*> m_Ptrs;
+};
+#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+
//////////////////////////////////////////////////////////////////////////
// Pipe creation for child process stdout capture
@@ -444,7 +507,7 @@ ProcessHandle::Kill()
std::error_code Ec;
if (!Wait(5000, Ec))
{
- // Graceful shutdown timed out — force-kill
+ // Graceful shutdown timed out - force-kill
kill(pid_t(m_Pid), SIGKILL);
Wait(1000, Ec);
}
@@ -691,6 +754,7 @@ BuildArgV(std::vector<char*>& Out, char* CommandLine)
++Cursor;
}
}
+
#endif // !WINDOWS || TESTS
#if ZEN_PLATFORM_WINDOWS
@@ -766,10 +830,14 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma
{
CreationFlags |= CREATE_NO_WINDOW;
}
- if (Options.Flags & CreateProcOptions::Flag_Windows_NewProcessGroup)
+ if (Options.Flags & CreateProcOptions::Flag_NewProcessGroup)
{
CreationFlags |= CREATE_NEW_PROCESS_GROUP;
}
+ if (Options.Flags & CreateProcOptions::Flag_BelowNormalPriority)
+ {
+ CreationFlags |= BELOW_NORMAL_PRIORITY_CLASS;
+ }
if (AssignToJob)
{
CreationFlags |= CREATE_SUSPENDED;
@@ -980,6 +1048,10 @@ CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view C
{
CreateProcFlags |= CREATE_NO_WINDOW;
}
+ if (Options.Flags & CreateProcOptions::Flag_BelowNormalPriority)
+ {
+ CreateProcFlags |= BELOW_NORMAL_PRIORITY_CLASS;
+ }
if (AssignToJob)
{
CreateProcFlags |= CREATE_SUSPENDED;
@@ -1070,23 +1142,30 @@ CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine
}
return CreateProcNormal(Executable, CommandLine, Options);
-#else
+#elif ZEN_PLATFORM_LINUX
+ // vfork uses CLONE_VM|CLONE_VFORK: the child shares the parent's address space and the
+ // parent is suspended until the child calls exec or _exit. This avoids page-table duplication
+ // and the ENOMEM that fork() produces on systems with strict overcommit (vm.overcommit_memory=2).
+ // All child-side setup uses only syscalls that do not modify user-space memory.
+ // Environment overrides are merged into envp before vfork so that setenv() is never called
+ // from the child (which would corrupt the shared address space).
std::vector<char*> ArgV;
std::string CommandLineZ(CommandLine);
BuildArgV(ArgV, CommandLineZ.data());
ArgV.push_back(nullptr);
- int ChildPid = fork();
+ EnvpHolder Envp(Options.Environment);
+
+ int ChildPid = vfork();
if (ChildPid < 0)
{
- ThrowLastError("Failed to fork a new child process");
+ ThrowLastError("Failed to vfork a new child process");
}
else if (ChildPid == 0)
{
if (Options.WorkingDirectory != nullptr)
{
- int Result = chdir(Options.WorkingDirectory->c_str());
- ZEN_UNUSED(Result);
+ chdir(Options.WorkingDirectory->c_str());
}
if (Options.StdoutPipe != nullptr && Options.StdoutPipe->WriteFd >= 0)
@@ -1118,23 +1197,109 @@ CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine
}
}
- if (Options.ProcessGroupId > 0)
+ if (Options.Flags & CreateProcOptions::Flag_NewProcessGroup)
+ {
+ setpgid(0, 0);
+ }
+ else if (Options.ProcessGroupId > 0)
{
setpgid(0, Options.ProcessGroupId);
}
- for (const auto& [Key, Value] : Options.Environment)
+ execve(Executable.c_str(), ArgV.data(), Envp.Data);
+ _exit(127);
+ }
+
+ if (Options.Flags & CreateProcOptions::Flag_BelowNormalPriority)
+ {
+ setpriority(PRIO_PROCESS, ChildPid, 5);
+ }
+
+ return ChildPid;
+#else // macOS
+ std::vector<char*> ArgV;
+ std::string CommandLineZ(CommandLine);
+ BuildArgV(ArgV, CommandLineZ.data());
+ ArgV.push_back(nullptr);
+
+ posix_spawn_file_actions_t FileActions;
+ posix_spawnattr_t Attr;
+
+ int Err = posix_spawn_file_actions_init(&FileActions);
+ if (Err != 0)
+ {
+ ThrowSystemError(Err, "posix_spawn_file_actions_init failed");
+ }
+ auto FileActionsGuard = MakeGuard([&] { posix_spawn_file_actions_destroy(&FileActions); });
+
+ Err = posix_spawnattr_init(&Attr);
+ if (Err != 0)
+ {
+ ThrowSystemError(Err, "posix_spawnattr_init failed");
+ }
+ auto AttrGuard = MakeGuard([&] { posix_spawnattr_destroy(&Attr); });
+
+ if (Options.WorkingDirectory != nullptr)
+ {
+ Err = posix_spawn_file_actions_addchdir_np(&FileActions, Options.WorkingDirectory->c_str());
+ if (Err != 0)
{
- setenv(Key.c_str(), Value.c_str(), 1);
+ ThrowSystemError(Err, "posix_spawn_file_actions_addchdir_np failed");
}
+ }
+
+ if (Options.StdoutPipe != nullptr && Options.StdoutPipe->WriteFd >= 0)
+ {
+ const int StdoutWriteFd = Options.StdoutPipe->WriteFd;
+ ZEN_ASSERT(StdoutWriteFd > STDERR_FILENO);
+ posix_spawn_file_actions_adddup2(&FileActions, StdoutWriteFd, STDOUT_FILENO);
- if (execv(Executable.c_str(), ArgV.data()) < 0)
+ if (Options.StderrPipe != nullptr && Options.StderrPipe->WriteFd >= 0)
{
- ThrowLastError("Failed to exec() a new process image");
+ const int StderrWriteFd = Options.StderrPipe->WriteFd;
+ ZEN_ASSERT(StderrWriteFd > STDERR_FILENO && StderrWriteFd != StdoutWriteFd);
+ posix_spawn_file_actions_adddup2(&FileActions, StderrWriteFd, STDERR_FILENO);
+ posix_spawn_file_actions_addclose(&FileActions, StderrWriteFd);
}
+ else
+ {
+ posix_spawn_file_actions_adddup2(&FileActions, StdoutWriteFd, STDERR_FILENO);
+ }
+
+ posix_spawn_file_actions_addclose(&FileActions, StdoutWriteFd);
+ }
+ else if (!Options.StdoutFile.empty())
+ {
+ posix_spawn_file_actions_addopen(&FileActions, STDOUT_FILENO, Options.StdoutFile.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
+ posix_spawn_file_actions_adddup2(&FileActions, STDOUT_FILENO, STDERR_FILENO);
}
- return ChildPid;
+ if (Options.Flags & CreateProcOptions::Flag_NewProcessGroup)
+ {
+ posix_spawnattr_setflags(&Attr, POSIX_SPAWN_SETPGROUP);
+ posix_spawnattr_setpgroup(&Attr, 0);
+ }
+ else if (Options.ProcessGroupId > 0)
+ {
+ posix_spawnattr_setflags(&Attr, POSIX_SPAWN_SETPGROUP);
+ posix_spawnattr_setpgroup(&Attr, Options.ProcessGroupId);
+ }
+
+ EnvpHolder Envp(Options.Environment);
+
+ pid_t ChildPid = 0;
+ Err = posix_spawn(&ChildPid, Executable.c_str(), &FileActions, &Attr, ArgV.data(), Envp.Data);
+ if (Err != 0)
+ {
+ ThrowSystemError(Err, "Failed to posix_spawn a new child process");
+ }
+
+ if (Options.Flags & CreateProcOptions::Flag_BelowNormalPriority)
+ {
+ setpriority(PRIO_PROCESS, ChildPid, 5);
+ }
+
+ return int(ChildPid);
#endif
}
@@ -1252,14 +1417,28 @@ JobObject::Initialize()
}
JOBOBJECT_EXTENDED_LIMIT_INFORMATION LimitInfo = {};
- LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
+ LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION;
if (!SetInformationJobObject(m_JobHandle, JobObjectExtendedLimitInformation, &LimitInfo, sizeof(LimitInfo)))
{
ZEN_WARN("Failed to set job object limits: {}", zen::GetLastError());
CloseHandle(m_JobHandle);
m_JobHandle = nullptr;
+ return;
}
+
+ // Prevent child processes from clearing SEM_NOGPFAULTERRORBOX, which
+ // suppresses WER/Dr. Watson crash dialogs. Without this, a crashing
+ // child can pop a modal dialog and block the monitor thread.
+# if !defined(JOB_OBJECT_UILIMIT_ERRORMODE)
+# define JOB_OBJECT_UILIMIT_ERRORMODE 0x00000400
+# endif
+ JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{};
+ UiRestrictions.UIRestrictionsClass = JOB_OBJECT_UILIMIT_ERRORMODE;
+ SetInformationJobObject(m_JobHandle, JobObjectBasicUIRestrictions, &UiRestrictions, sizeof(UiRestrictions));
+
+ // Set error mode on the current process so children inherit it.
+ SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX);
}
bool
@@ -1576,7 +1755,7 @@ GetProcessCommandLine(int Pid, std::error_code& OutEc)
++p; // skip null terminator of argv[0]
}
- // Build result: remaining entries joined by spaces (inter-arg nulls → spaces)
+ // Build result: remaining entries joined by spaces (inter-arg nulls -> spaces)
std::string Result;
Result.reserve(static_cast<size_t>(End - p));
for (const char* q = p; q < End; ++q)
@@ -1914,7 +2093,7 @@ GetProcessMetrics(const ProcessHandle& Handle, ProcessMetrics& OutMetrics)
{
Buf[Len] = '\0';
- // Skip past "pid (name) " — find last ')' to handle names containing spaces or parens
+ // Skip past "pid (name) " - find last ')' to handle names containing spaces or parens
const char* P = strrchr(Buf, ')');
if (P)
{
diff --git a/src/zencore/refcount.cpp b/src/zencore/refcount.cpp
index f19afe715..674b154e0 100644
--- a/src/zencore/refcount.cpp
+++ b/src/zencore/refcount.cpp
@@ -35,29 +35,29 @@ refcount_forcelink()
TEST_SUITE_BEGIN("core.refcount");
-TEST_CASE("RefPtr")
+TEST_CASE("Ref")
{
- RefPtr<TestRefClass> Ref;
- Ref = new TestRefClass;
+ Ref<TestRefClass> RefA;
+ RefA = new TestRefClass;
bool IsDestroyed = false;
- Ref->OnDestroy = [&] { IsDestroyed = true; };
+ RefA->OnDestroy = [&] { IsDestroyed = true; };
CHECK(IsDestroyed == false);
- CHECK(Ref->RefCount() == 1);
+ CHECK(RefA->RefCount() == 1);
- RefPtr<TestRefClass> Ref2;
- Ref2 = Ref;
+ Ref<TestRefClass> RefB;
+ RefB = RefA;
CHECK(IsDestroyed == false);
- CHECK(Ref->RefCount() == 2);
+ CHECK(RefA->RefCount() == 2);
- RefPtr<TestRefClass> Ref3;
- Ref2 = Ref3;
+ Ref<TestRefClass> RefC;
+ RefB = RefC;
CHECK(IsDestroyed == false);
- CHECK(Ref->RefCount() == 1);
- Ref = Ref3;
+ CHECK(RefA->RefCount() == 1);
+ RefA = RefC;
CHECK(IsDestroyed == true);
}
diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp
index 8491bef64..7e3f33191 100644
--- a/src/zencore/sentryintegration.cpp
+++ b/src/zencore/sentryintegration.cpp
@@ -250,7 +250,7 @@ SentryIntegration::Initialize(const Config& Conf, const std::string& CommandLine
if (SentryOptions == nullptr)
{
- // OOM — skip sentry entirely rather than crashing on the subsequent set calls
+ // OOM - skip sentry entirely rather than crashing on the subsequent set calls
m_SentryErrorCode = -1;
m_IsInitialized = true;
return;
diff --git a/src/zencore/sharedbuffer.cpp b/src/zencore/sharedbuffer.cpp
index 8dc6d49d8..48730e670 100644
--- a/src/zencore/sharedbuffer.cpp
+++ b/src/zencore/sharedbuffer.cpp
@@ -100,7 +100,7 @@ SharedBuffer::MakeView(MemoryView View, SharedBuffer OuterBuffer)
return OuterBuffer;
}
- IoBufferCore* NewCore = new IoBufferCore(OuterBuffer.m_Buffer, View.GetData(), View.GetSize());
+ IoBufferCore* NewCore = new IoBufferCore(OuterBuffer.m_Buffer.Get(), View.GetData(), View.GetSize());
NewCore->SetIsImmutable(true);
return SharedBuffer(NewCore);
}
diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp
index 358722b0b..44f78aa75 100644
--- a/src/zencore/string.cpp
+++ b/src/zencore/string.cpp
@@ -381,6 +381,34 @@ NiceNumGeneral(uint64_t Num, std::span<char> Buffer, NicenumFormat Format)
}
size_t
+ThousandsToBuffer(uint64_t Num, std::span<char> Buffer)
+{
+ // Format into a temporary buffer without separators
+ char Tmp[24];
+ int Len = snprintf(Tmp, sizeof(Tmp), "%llu", (unsigned long long)Num);
+
+ // Insert comma separators
+ int SepCount = (Len - 1) / 3;
+ int TotalLen = Len + SepCount;
+ ZEN_ASSERT(TotalLen < (int)Buffer.size());
+
+ int Src = Len - 1;
+ int Dst = TotalLen;
+ Buffer[Dst--] = '\0';
+
+ for (int i = 0; Src >= 0; i++)
+ {
+ if (i > 0 && i % 3 == 0)
+ {
+ Buffer[Dst--] = ',';
+ }
+ Buffer[Dst--] = Tmp[Src--];
+ }
+
+ return TotalLen;
+}
+
+size_t
NiceNumToBuffer(uint64_t Num, std::span<char> Buffer)
{
return NiceNumGeneral(Num, Buffer, kNicenum1024);
@@ -515,6 +543,40 @@ template class StringBuilderImpl<wchar_t>;
//////////////////////////////////////////////////////////////////////////
void
+StringBuilderBase::AppendPaddedInt(int64_t Value, int MinWidth)
+{
+ char Buf[24];
+ char* End = Buf + sizeof(Buf);
+ char* Ptr = End;
+ bool Negative = Value < 0;
+ uint64_t Abs = Negative ? uint64_t(-Value) : uint64_t(Value);
+ do
+ {
+ *--Ptr = '0' + char(Abs % 10);
+ Abs /= 10;
+ } while (Abs > 0);
+ while ((End - Ptr) < MinWidth)
+ {
+ *--Ptr = '0';
+ }
+ if (Negative)
+ {
+ *--Ptr = '-';
+ }
+ AppendRange(Ptr, End);
+}
+
+void
+StringBuilderBase::AppendFill(char C, size_t Count)
+{
+ EnsureCapacity(Count);
+ std::memset(m_CurPos, C, Count);
+ m_CurPos += Count;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+void
UrlDecode(std::string_view InUrl, StringBuilderBase& OutUrl)
{
std::string_view::size_type i = 0;
diff --git a/src/zencore/system.cpp b/src/zencore/system.cpp
index 6909e1a9b..486050d83 100644
--- a/src/zencore/system.cpp
+++ b/src/zencore/system.cpp
@@ -148,6 +148,18 @@ GetSystemMetrics()
return Metrics;
}
+void
+RefreshDynamicSystemMetrics(SystemMetrics& Metrics)
+{
+ MEMORYSTATUSEX MemStatus{.dwLength = sizeof(MEMORYSTATUSEX)};
+ GlobalMemoryStatusEx(&MemStatus);
+
+ Metrics.AvailSystemMemoryMiB = MemStatus.ullAvailPhys / 1024 / 1024;
+ Metrics.AvailVirtualMemoryMiB = MemStatus.ullAvailVirtual / 1024 / 1024;
+ Metrics.AvailPageFileMiB = MemStatus.ullAvailPageFile / 1024 / 1024;
+ Metrics.UptimeSeconds = GetTickCount64() / 1000;
+}
+
std::vector<std::string>
GetLocalIpAddresses()
{
@@ -324,6 +336,51 @@ GetSystemMetrics()
return Metrics;
}
+
+void
+RefreshDynamicSystemMetrics(SystemMetrics& Metrics)
+{
+ long PageSize = sysconf(_SC_PAGE_SIZE);
+ long AvailPages = sysconf(_SC_AVPHYS_PAGES);
+
+ if (AvailPages > 0 && PageSize > 0)
+ {
+ Metrics.AvailSystemMemoryMiB = (AvailPages * PageSize) / 1024 / 1024;
+ Metrics.AvailVirtualMemoryMiB = Metrics.AvailSystemMemoryMiB;
+ }
+
+ if (FILE* UptimeFile = fopen("/proc/uptime", "r"))
+ {
+ double UptimeSec = 0;
+ if (fscanf(UptimeFile, "%lf", &UptimeSec) == 1)
+ {
+ Metrics.UptimeSeconds = static_cast<uint64_t>(UptimeSec);
+ }
+ fclose(UptimeFile);
+ }
+
+ if (FILE* MemInfo = fopen("/proc/meminfo", "r"))
+ {
+ char Line[256];
+ long SwapFree = 0;
+
+ while (fgets(Line, sizeof(Line), MemInfo))
+ {
+ if (strncmp(Line, "SwapFree:", 9) == 0)
+ {
+ sscanf(Line, "SwapFree: %ld kB", &SwapFree);
+ break;
+ }
+ }
+ fclose(MemInfo);
+
+ if (SwapFree > 0)
+ {
+ Metrics.AvailPageFileMiB = SwapFree / 1024;
+ }
+ }
+}
+
#elif ZEN_PLATFORM_MAC
std::string
GetMachineName()
@@ -398,6 +455,36 @@ GetSystemMetrics()
return Metrics;
}
+
+void
+RefreshDynamicSystemMetrics(SystemMetrics& Metrics)
+{
+ vm_size_t PageSize = 0;
+ host_page_size(mach_host_self(), &PageSize);
+
+ vm_statistics64_data_t VmStats;
+ mach_msg_type_number_t InfoCount = sizeof(VmStats) / sizeof(natural_t);
+ host_statistics64(mach_host_self(), HOST_VM_INFO64, (host_info64_t)&VmStats, &InfoCount);
+
+ uint64_t FreeMemory = (uint64_t)(VmStats.free_count + VmStats.inactive_count) * PageSize;
+ Metrics.AvailSystemMemoryMiB = FreeMemory / 1024 / 1024;
+ Metrics.AvailVirtualMemoryMiB = Metrics.VirtualMemoryMiB;
+
+ xsw_usage SwapUsage;
+ size_t Size = sizeof(SwapUsage);
+ sysctlbyname("vm.swapusage", &SwapUsage, &Size, nullptr, 0);
+ Metrics.AvailPageFileMiB = (SwapUsage.xsu_total - SwapUsage.xsu_used) / 1024 / 1024;
+
+ struct timeval BootTime
+ {
+ };
+ Size = sizeof(BootTime);
+ if (sysctlbyname("kern.boottime", &BootTime, &Size, nullptr, 0) == 0)
+ {
+ Metrics.UptimeSeconds = static_cast<uint64_t>(time(nullptr) - BootTime.tv_sec);
+ }
+}
+
#else
# error "Unknown platform"
#endif
@@ -655,11 +742,16 @@ struct SystemMetricsTracker::Impl
std::mutex Mutex;
CpuSampler Sampler;
+ SystemMetrics CachedMetrics;
float CachedCpuPercent = 0.0f;
Clock::time_point NextSampleTime = Clock::now();
std::chrono::milliseconds MinInterval;
- explicit Impl(std::chrono::milliseconds InMinInterval) : MinInterval(InMinInterval) {}
+ explicit Impl(std::chrono::milliseconds InMinInterval) : MinInterval(InMinInterval)
+ {
+ // Capture topology and total memory once; these don't change at runtime
+ CachedMetrics = GetSystemMetrics();
+ }
float SampleCpu()
{
@@ -683,7 +775,8 @@ ExtendedSystemMetrics
SystemMetricsTracker::Query()
{
ExtendedSystemMetrics Metrics;
- static_cast<SystemMetrics&>(Metrics) = GetSystemMetrics();
+ static_cast<SystemMetrics&>(Metrics) = m_Impl->CachedMetrics;
+ RefreshDynamicSystemMetrics(Metrics);
std::lock_guard Lock(m_Impl->Mutex);
Metrics.CpuUsagePercent = m_Impl->SampleCpu();
diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp
index 9f88a3365..67285dcf1 100644
--- a/src/zencore/testing.cpp
+++ b/src/zencore/testing.cpp
@@ -39,7 +39,7 @@ PrintCrashCallstack([[maybe_unused]] const char* SignalName)
// Use write() + backtrace_symbols_fd() which are async-signal-safe
write(STDERR_FILENO, "\n*** Caught ", 12);
write(STDERR_FILENO, SignalName, strlen(SignalName));
- write(STDERR_FILENO, " — callstack:\n", 15);
+ write(STDERR_FILENO, " - callstack:\n", 15);
void* Frames[64];
int FrameCount = backtrace(Frames, 64);
diff --git a/src/zencore/testutils.cpp b/src/zencore/testutils.cpp
index c9908aec8..44446bd40 100644
--- a/src/zencore/testutils.cpp
+++ b/src/zencore/testutils.cpp
@@ -30,11 +30,15 @@ ScopedTemporaryDirectory::ScopedTemporaryDirectory() : m_RootPath(CreateTemporar
{
}
-ScopedTemporaryDirectory::ScopedTemporaryDirectory(std::filesystem::path Directory) : m_RootPath(Directory)
+ScopedTemporaryDirectory::ScopedTemporaryDirectory(std::filesystem::path Directory)
+: m_RootPath(Directory.empty() ? CreateTemporaryDirectory() : Directory)
{
- std::error_code Ec;
- DeleteDirectories(Directory, Ec);
- CreateDirectories(Directory);
+ if (!Directory.empty())
+ {
+ std::error_code Ec;
+ DeleteDirectories(Directory, Ec);
+ CreateDirectories(Directory);
+ }
}
ScopedTemporaryDirectory::~ScopedTemporaryDirectory()
diff --git a/src/zencore/thread.cpp b/src/zencore/thread.cpp
index 067e66c0d..fd72afaa7 100644
--- a/src/zencore/thread.cpp
+++ b/src/zencore/thread.cpp
@@ -99,7 +99,7 @@ SetNameInternal(DWORD thread_id, const char* name)
#endif
void
-SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName)
+SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName, [[maybe_unused]] int32_t SortHint)
{
constexpr std::string_view::size_type MaxThreadNameLength = 255;
std::string_view LimitedThreadName = ThreadName.substr(0, MaxThreadNameLength);
@@ -108,7 +108,7 @@ SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName)
const int ThreadId = GetCurrentThreadId();
#if ZEN_WITH_TRACE
- trace::ThreadRegister(ThreadNameZ.c_str(), /* system id */ ThreadId, /* sort id */ 0);
+ trace::ThreadRegister(ThreadNameZ.c_str(), /* system id */ ThreadId, /* sort id */ SortHint);
#endif // ZEN_WITH_TRACE
#if ZEN_PLATFORM_WINDOWS
diff --git a/src/zencore/trace.cpp b/src/zencore/trace.cpp
index 7c195e69f..d7084bbd1 100644
--- a/src/zencore/trace.cpp
+++ b/src/zencore/trace.cpp
@@ -6,6 +6,7 @@
# include <zencore/zencore.h>
# include <zencore/commandline.h>
# include <zencore/string.h>
+# include <zencore/thread.h>
# include <zencore/logging.h>
# define TRACE_IMPLEMENT 1
@@ -121,7 +122,7 @@ TraceInit(std::string_view ProgramName)
const char* CommandLineString = "";
# endif
- trace::ThreadRegister("main", /* system id */ 0, /* sort id */ 0);
+ trace::ThreadRegister("main", /* system id */ GetCurrentThreadId(), /* sort id */ -1);
trace::DescribeSession(ProgramName,
# if ZEN_BUILD_DEBUG
trace::Build::Debug,
diff --git a/src/zencore/zencore.cpp b/src/zencore/zencore.cpp
index 8c29a8962..c1ac63621 100644
--- a/src/zencore/zencore.cpp
+++ b/src/zencore/zencore.cpp
@@ -273,7 +273,7 @@ zencore_forcelinktests()
zen::uid_forcelink();
zen::uson_forcelink();
zen::usonbuilder_forcelink();
- zen::usonpackage_forcelink();
+ zen::cbpackage_forcelink();
zen::cbjson_forcelink();
zen::cbyaml_forcelink();
zen::workthreadpool_forcelink();
diff --git a/src/zenhorde/README.md b/src/zenhorde/README.md
new file mode 100644
index 000000000..13beaa968
--- /dev/null
+++ b/src/zenhorde/README.md
@@ -0,0 +1,17 @@
+# Horde Compute integration
+
+Zen compute can use Horde to provision runner nodes.
+
+## Launch a coordinator instance
+
+Coordinator instances provision compute resources (runners) from a compute provider such as Horde, and surface an interface which allows zenserver instances to discover endpoints which they can submit actions to.
+
+```bash
+zenserver compute --horde-enabled --horde-server=https://horde.dev.net:13340/ --horde-max-cores=512 --horde-zen-service-port=25000 --http=asio
+```
+
+## Use a coordinator
+
+```bash
+zen exec beacon --path=e:\lyra-recording --orch=http://localhost:8558
+```
diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp
index 819b2d0cb..029b98e55 100644
--- a/src/zenhorde/hordeagent.cpp
+++ b/src/zenhorde/hordeagent.cpp
@@ -8,290 +8,479 @@
#include <zencore/logging.h>
#include <zencore/trace.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
#include <cstring>
-#include <unordered_map>
namespace zen::horde {
-HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info)
-{
- ZEN_TRACE_CPU("HordeAgent::Connect");
+// --- AsyncHordeAgent ---
- auto Transport = std::make_unique<TcpComputeTransport>(Info);
- if (!Transport->IsValid())
+static const char*
+GetStateName(AsyncHordeAgent::State S)
+{
+ switch (S)
{
- ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort());
- return;
+ case AsyncHordeAgent::State::Idle:
+ return "idle";
+ case AsyncHordeAgent::State::Connecting:
+ return "connect";
+ case AsyncHordeAgent::State::WaitAgentAttach:
+ return "agent-attach";
+ case AsyncHordeAgent::State::SentFork:
+ return "fork";
+ case AsyncHordeAgent::State::WaitChildAttach:
+ return "child-attach";
+ case AsyncHordeAgent::State::Uploading:
+ return "upload";
+ case AsyncHordeAgent::State::Executing:
+ return "execute";
+ case AsyncHordeAgent::State::Polling:
+ return "poll";
+ case AsyncHordeAgent::State::Done:
+ return "done";
+ default:
+ return "unknown";
}
+}
- // The 64-byte nonce is always sent unencrypted as the first thing on the wire.
- // The Horde agent uses this to identify which lease this connection belongs to.
- Transport->Send(Info.Nonce, sizeof(Info.Nonce));
+AsyncHordeAgent::AsyncHordeAgent(asio::io_context& IoContext) : m_IoContext(IoContext), m_Log(zen::logging::Get("horde.agent.async"))
+{
+}
- std::unique_ptr<ComputeTransport> FinalTransport = std::move(Transport);
- if (Info.EncryptionMode == Encryption::AES)
+AsyncHordeAgent::~AsyncHordeAgent()
+{
+ Cancel();
+}
+
+void
+AsyncHordeAgent::Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone)
+{
+ m_Config = std::move(Config);
+ m_OnDone = std::move(OnDone);
+ m_State = State::Connecting;
+ DoConnect();
+}
+
+void
+AsyncHordeAgent::Cancel()
+{
+ m_Cancelled = true;
+ if (m_Socket)
{
- FinalTransport = std::make_unique<AesComputeTransport>(Info.Key, std::move(FinalTransport));
- if (!FinalTransport->IsValid())
- {
- ZEN_WARN("failed to create AES transport");
- return;
- }
+ m_Socket->Close();
+ }
+ else if (m_TcpTransport)
+ {
+ // Cancelled before handshake completed - tear down the pending TCP connect.
+ m_TcpTransport->Close();
}
+}
- // Create multiplexed socket and channels
- m_Socket = std::make_unique<ComputeSocket>(std::move(FinalTransport));
+void
+AsyncHordeAgent::DoConnect()
+{
+ ZEN_TRACE_CPU("AsyncHordeAgent::DoConnect");
- // Channel 0 is the agent control channel (handles Attach/Fork handshake).
- // Channel 100 is the child I/O channel (handles file upload and remote execution).
- Ref<ComputeChannel> AgentComputeChannel = m_Socket->CreateChannel(0);
- Ref<ComputeChannel> ChildComputeChannel = m_Socket->CreateChannel(100);
+ m_TcpTransport = std::make_unique<AsyncTcpComputeTransport>(m_IoContext);
+
+ auto Self = shared_from_this();
+ m_TcpTransport->AsyncConnect(m_Config.Machine, [this, Self](const std::error_code& Ec) { OnConnected(Ec); });
+}
- if (!AgentComputeChannel || !ChildComputeChannel)
+void
+AsyncHordeAgent::OnConnected(const std::error_code& Ec)
+{
+ if (Ec || m_Cancelled)
{
- ZEN_WARN("failed to create compute channels");
+ if (Ec)
+ {
+ ZEN_WARN("connect failed: {}", Ec.message());
+ }
+ Finish(false);
return;
}
- m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(AgentComputeChannel));
- m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildComputeChannel));
+ // Optionally wrap with AES encryption
+ std::unique_ptr<AsyncComputeTransport> FinalTransport = std::move(m_TcpTransport);
+ if (m_Config.Machine.EncryptionMode == Encryption::AES)
+ {
+ FinalTransport = std::make_unique<AsyncAesComputeTransport>(m_Config.Machine.Key, std::move(FinalTransport), m_IoContext);
+ }
+
+ // Create the multiplexed socket and register channels. Ownership of the transport
+ // moves into the socket here - no need to retain a separate m_Transport field.
+ m_Socket = std::make_shared<AsyncComputeSocket>(std::move(FinalTransport), m_IoContext);
+
+ m_AgentChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 0, m_IoContext);
+ m_ChildChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 100, m_IoContext);
+
+ m_Socket->RegisterChannel(
+ 0,
+ [this](std::vector<uint8_t> Data) { m_AgentChannel->OnFrame(std::move(Data)); },
+ [this]() { m_AgentChannel->OnDetach(); });
- m_IsValid = true;
+ m_Socket->RegisterChannel(
+ 100,
+ [this](std::vector<uint8_t> Data) { m_ChildChannel->OnFrame(std::move(Data)); },
+ [this]() { m_ChildChannel->OnDetach(); });
+
+ m_Socket->StartRecvPump();
+
+ m_State = State::WaitAgentAttach;
+ DoWaitAgentAttach();
}
-HordeAgent::~HordeAgent()
+void
+AsyncHordeAgent::DoWaitAgentAttach()
{
- CloseConnection();
+ auto Self = shared_from_this();
+ m_AgentChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnAgentResponse(Type, Data, Size);
+ });
}
-bool
-HordeAgent::BeginCommunication()
+void
+AsyncHordeAgent::OnAgentResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/)
{
- ZEN_TRACE_CPU("HordeAgent::BeginCommunication");
-
- if (!m_IsValid)
+ if (m_Cancelled)
{
- return false;
+ Finish(false);
+ return;
}
- // Start the send/recv pump threads
- m_Socket->StartCommunication();
-
- // Wait for Attach on agent channel
- AgentMessageType Type = m_AgentChannel->ReadResponse(5000);
if (Type == AgentMessageType::None)
{
ZEN_WARN("timed out waiting for Attach on agent channel");
- return false;
+ Finish(false);
+ return;
}
+
if (Type != AgentMessageType::Attach)
{
ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type));
- return false;
+ Finish(false);
+ return;
}
- // Fork tells the remote agent to create child channel 100 with a 4MB buffer.
- // After this, the agent will send an Attach on the child channel.
+ m_State = State::SentFork;
+ DoSendFork();
+}
+
+void
+AsyncHordeAgent::DoSendFork()
+{
m_AgentChannel->Fork(100, 4 * 1024 * 1024);
- // Wait for Attach on child channel
- Type = m_ChildChannel->ReadResponse(5000);
+ m_State = State::WaitChildAttach;
+ DoWaitChildAttach();
+}
+
+void
+AsyncHordeAgent::DoWaitChildAttach()
+{
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnChildAttachResponse(Type, Data, Size);
+ });
+}
+
+void
+AsyncHordeAgent::OnChildAttachResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/)
+{
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
+
if (Type == AgentMessageType::None)
{
ZEN_WARN("timed out waiting for Attach on child channel");
- return false;
+ Finish(false);
+ return;
}
+
if (Type != AgentMessageType::Attach)
{
ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type));
- return false;
+ Finish(false);
+ return;
}
- return true;
+ m_State = State::Uploading;
+ m_CurrentBundleIndex = 0;
+ DoUploadNext();
}
-bool
-HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator)
+void
+AsyncHordeAgent::DoUploadNext()
{
- ZEN_TRACE_CPU("HordeAgent::UploadBinaries");
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
+
+ if (m_CurrentBundleIndex >= m_Config.Bundles.size())
+ {
+ // All bundles uploaded - proceed to execute
+ m_State = State::Executing;
+ DoExecute();
+ return;
+ }
- m_ChildChannel->UploadFiles("", BundleLocator.c_str());
+ const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex];
+ m_ChildChannel->UploadFiles("", Locator.c_str());
- std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles;
+ // Enter the ReadBlob/Blob upload loop
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnUploadResponse(Type, Data, Size);
+ });
+}
- auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* {
- std::string Key(Locator);
+void
+AsyncHordeAgent::OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size)
+{
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
- if (auto It = BlobFiles.find(Key); It != BlobFiles.end())
+ if (Type == AgentMessageType::None)
+ {
+ if (m_ChildChannel->IsDetached())
{
- return It->second.get();
+ ZEN_WARN("connection lost during upload");
+ Finish(false);
+ return;
}
+ // Timeout - retry read
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnUploadResponse(Type, Data, Size);
+ });
+ return;
+ }
- const std::filesystem::path Path = BundleDir / (Key + ".blob");
- std::error_code Ec;
- auto File = std::make_unique<BasicFile>();
- File->Open(Path, BasicFile::Mode::kRead, Ec);
+ if (Type == AgentMessageType::WriteFilesResponse)
+ {
+ // This bundle upload is done - move to next
+ ++m_CurrentBundleIndex;
+ DoUploadNext();
+ return;
+ }
- if (Ec)
+ if (Type == AgentMessageType::Exception)
+ {
+ ExceptionInfo Ex;
+ if (!AsyncAgentMessageChannel::ReadException(Data, Size, Ex))
{
- ZEN_ERROR("cannot read blob file: '{}'", Path);
- return nullptr;
+ ZEN_ERROR("malformed Exception message during upload (size={})", Size);
+ Finish(false);
+ return;
}
+ ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description);
+ Finish(false);
+ return;
+ }
- BasicFile* Ptr = File.get();
- BlobFiles.emplace(std::move(Key), std::move(File));
- return Ptr;
- };
-
- // The upload protocol is request-driven: we send WriteFiles, then the remote agent
- // sends ReadBlob requests for each blob it needs. We respond with Blob data until
- // the agent sends WriteFilesResponse indicating the upload is complete.
- constexpr int32_t ReadResponseTimeoutMs = 1000;
+ if (Type != AgentMessageType::ReadBlob)
+ {
+ ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type));
+ Finish(false);
+ return;
+ }
- for (;;)
+ // Handle ReadBlob request
+ BlobRequest Req;
+ if (!AsyncAgentMessageChannel::ReadBlobRequest(Data, Size, Req))
{
- bool TimedOut = false;
+ ZEN_ERROR("malformed ReadBlob message during upload (size={})", Size);
+ Finish(false);
+ return;
+ }
- if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob)
- {
- if (TimedOut)
- {
- continue;
- }
- // End of stream - check if it was a successful upload
- if (Type == AgentMessageType::WriteFilesResponse)
- {
- return true;
- }
- else if (Type == AgentMessageType::Exception)
- {
- ExceptionInfo Ex;
- m_ChildChannel->ReadException(Ex);
- ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description);
- }
- else
- {
- ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type));
- }
- return false;
- }
+ const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex];
+ const std::filesystem::path BlobPath = BundleDir / (std::string(Req.Locator) + ".blob");
- BlobRequest Req;
- m_ChildChannel->ReadBlobRequest(Req);
+ std::error_code FsEc;
+ BasicFile File;
+ File.Open(BlobPath, BasicFile::Mode::kRead, FsEc);
- BasicFile* File = FindOrOpenBlob(Req.Locator);
- if (!File)
- {
- return false;
- }
+ if (FsEc)
+ {
+ ZEN_ERROR("cannot read blob file: '{}'", BlobPath);
+ Finish(false);
+ return;
+ }
- // Read from offset to end of file
- const uint64_t TotalSize = File->FileSize();
- const uint64_t Offset = static_cast<uint64_t>(Req.Offset);
- if (Offset >= TotalSize)
- {
- ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize);
- m_ChildChannel->Blob(nullptr, 0);
- continue;
- }
+ const uint64_t TotalSize = File.FileSize();
+ const uint64_t Offset = static_cast<uint64_t>(Req.Offset);
+ if (Offset >= TotalSize)
+ {
+ ZEN_ERROR("blob request beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize);
+ m_ChildChannel->Blob(nullptr, 0);
+ }
+ else
+ {
+ const IoBuffer FileData = File.ReadRange(Offset, Min(Req.Length, TotalSize - Offset));
+ m_ChildChannel->Blob(static_cast<const uint8_t*>(FileData.GetData()), FileData.GetSize());
+ }
+
+ // Continue the upload loop
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnUploadResponse(Type, Data, Size);
+ });
+}
- const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset));
- m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize());
+void
+AsyncHordeAgent::DoExecute()
+{
+ ZEN_TRACE_CPU("AsyncHordeAgent::DoExecute");
+
+ std::vector<const char*> ArgPtrs;
+ ArgPtrs.reserve(m_Config.Args.size());
+ for (const std::string& Arg : m_Config.Args)
+ {
+ ArgPtrs.push_back(Arg.c_str());
}
+
+ m_ChildChannel->Execute(m_Config.Executable.c_str(),
+ ArgPtrs.data(),
+ ArgPtrs.size(),
+ nullptr,
+ nullptr,
+ 0,
+ m_Config.UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None);
+
+ ZEN_INFO("remote execution started on [{}:{}] lease={}",
+ m_Config.Machine.GetConnectionAddress(),
+ m_Config.Machine.GetConnectionPort(),
+ m_Config.Machine.LeaseId);
+
+ m_State = State::Polling;
+ DoPoll();
}
void
-HordeAgent::Execute(const char* Exe,
- const char* const* Args,
- size_t NumArgs,
- const char* WorkingDir,
- const char* const* EnvVars,
- size_t NumEnvVars,
- bool UseWine)
+AsyncHordeAgent::DoPoll()
{
- ZEN_TRACE_CPU("HordeAgent::Execute");
- m_ChildChannel
- ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None);
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
+
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(100, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnPollResponse(Type, Data, Size);
+ });
}
-bool
-HordeAgent::Poll(bool LogOutput)
+void
+AsyncHordeAgent::OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size)
{
- constexpr int32_t ReadResponseTimeoutMs = 100;
- AgentMessageType Type;
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
- while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None)
+ switch (Type)
{
- switch (Type)
- {
- case AgentMessageType::ExecuteOutput:
+ case AgentMessageType::None:
+ if (m_ChildChannel->IsDetached())
+ {
+ ZEN_WARN("connection lost during execution");
+ Finish(false);
+ }
+ else
+ {
+ // Timeout - poll again
+ DoPoll();
+ }
+ break;
+
+ case AgentMessageType::ExecuteOutput:
+ // Silently consume remote stdout (matching LogOutput=false in provisioner)
+ DoPoll();
+ break;
+
+ case AgentMessageType::ExecuteResult:
+ {
+ int32_t ExitCode = -1;
+ if (!AsyncAgentMessageChannel::ReadExecuteResult(Data, Size, ExitCode))
{
- if (LogOutput && m_ChildChannel->GetResponseSize() > 0)
- {
- const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData());
- size_t ResponseSize = m_ChildChannel->GetResponseSize();
-
- // Trim trailing newlines
- while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r'))
- {
- --ResponseSize;
- }
-
- if (ResponseSize > 0)
- {
- const std::string_view Output(ResponseData, ResponseSize);
- ZEN_INFO("[remote] {}", Output);
- }
- }
+ // A remote with a malformed ExecuteResult cannot be trusted to report
+ // process outcome - treat as a protocol error and tear down rather than
+ // silently recording \"exited with -1\".
+ ZEN_ERROR("malformed ExecuteResult (size={}, lease={}) - disconnecting", Size, m_Config.Machine.LeaseId);
+ Finish(false);
break;
}
+ ZEN_INFO("remote process exited with code {} (lease={})", ExitCode, m_Config.Machine.LeaseId);
+ Finish(ExitCode == 0, ExitCode);
+ }
+ break;
- case AgentMessageType::ExecuteResult:
+ case AgentMessageType::Exception:
+ {
+ ExceptionInfo Ex;
+ if (AsyncAgentMessageChannel::ReadException(Data, Size, Ex))
{
- if (m_ChildChannel->GetResponseSize() == sizeof(int32_t))
- {
- int32_t ExitCode;
- memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t));
- ZEN_INFO("remote process exited with code {}", ExitCode);
- }
- m_IsValid = false;
- return false;
+ ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description);
}
-
- case AgentMessageType::Exception:
+ else
{
- ExceptionInfo Ex;
- m_ChildChannel->ReadException(Ex);
- ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description);
- m_HasErrors = true;
- break;
+ ZEN_ERROR("malformed Exception message (size={})", Size);
}
+ Finish(false);
+ }
+ break;
- default:
- break;
- }
+ default:
+ DoPoll();
+ break;
}
-
- return m_IsValid && !m_HasErrors;
}
void
-HordeAgent::CloseConnection()
+AsyncHordeAgent::Finish(bool Success, int32_t ExitCode)
{
- if (m_ChildChannel)
+ if (m_State == State::Done)
{
- m_ChildChannel->Close();
+ return; // Already finished
}
- if (m_AgentChannel)
+
+ if (!Success)
{
- m_AgentChannel->Close();
+ ZEN_WARN("agent failed during {} (lease={})", GetStateName(m_State), m_Config.Machine.LeaseId);
}
-}
-bool
-HordeAgent::IsValid() const
-{
- return m_IsValid && !m_HasErrors;
+ m_State = State::Done;
+
+ if (m_Socket)
+ {
+ m_Socket->Close();
+ }
+
+ if (m_OnDone)
+ {
+ AsyncAgentResult Result;
+ Result.Success = Success;
+ Result.ExitCode = ExitCode;
+ Result.CoreCount = m_Config.Machine.LogicalCores;
+
+ auto Handler = std::move(m_OnDone);
+ m_OnDone = nullptr;
+ Handler(Result);
+ }
}
} // namespace zen::horde
diff --git a/src/zenhorde/hordeagent.h b/src/zenhorde/hordeagent.h
index e0ae89ead..a5b3248ab 100644
--- a/src/zenhorde/hordeagent.h
+++ b/src/zenhorde/hordeagent.h
@@ -10,68 +10,107 @@
#include <zencore/logbase.h>
#include <filesystem>
+#include <functional>
#include <memory>
#include <string>
+#include <vector>
+
+namespace asio {
+class io_context;
+}
namespace zen::horde {
-/** Manages the lifecycle of a single Horde compute agent.
+class AsyncComputeTransport;
+
+/** Result passed to the completion handler when an async agent finishes. */
+struct AsyncAgentResult
+{
+ bool Success = false;
+ int32_t ExitCode = -1;
+ uint16_t CoreCount = 0; ///< Logical cores on the provisioned machine
+};
+
+/** Completion handler for async agent lifecycle. */
+using AsyncAgentCompletionHandler = std::function<void(const AsyncAgentResult&)>;
+
+/** Configuration for launching a remote zenserver instance via an async agent. */
+struct AsyncAgentConfig
+{
+ MachineInfo Machine;
+ std::vector<std::pair<std::string, std::filesystem::path>> Bundles; ///< (locator, bundleDir) pairs
+ std::string Executable;
+ std::vector<std::string> Args;
+ bool UseWine = false;
+};
+
+/** Async agent that manages the full lifecycle of a single Horde compute connection.
*
- * Handles the full connection sequence for one provisioned machine:
- * 1. Connect via TCP transport (with optional AES encryption wrapping)
- * 2. Create a multiplexed ComputeSocket with agent (channel 0) and child (channel 100)
- * 3. Perform the Attach/Fork handshake to establish the child channel
- * 4. Upload zenserver binary via the WriteFiles/ReadBlob protocol
- * 5. Execute zenserver remotely via ExecuteV2
- * 6. Poll for ExecuteOutput (stdout) and ExecuteResult (exit code)
+ * Driven by a state machine using callbacks on a shared io_context - no dedicated
+ * threads. Call Start() to begin the connection/handshake/upload/execute/poll
+ * sequence. The completion handler is invoked when the remote process exits or
+ * an error occurs.
*/
-class HordeAgent
+class AsyncHordeAgent : public std::enable_shared_from_this<AsyncHordeAgent>
{
public:
- explicit HordeAgent(const MachineInfo& Info);
- ~HordeAgent();
+ AsyncHordeAgent(asio::io_context& IoContext);
+ ~AsyncHordeAgent();
- HordeAgent(const HordeAgent&) = delete;
- HordeAgent& operator=(const HordeAgent&) = delete;
+ AsyncHordeAgent(const AsyncHordeAgent&) = delete;
+ AsyncHordeAgent& operator=(const AsyncHordeAgent&) = delete;
- /** Perform the channel setup handshake (Attach on agent channel, Fork, Attach on child channel).
- * Returns false if the handshake times out or receives an unexpected message. */
- bool BeginCommunication();
+ /** Start the full agent lifecycle. The completion handler is called exactly once. */
+ void Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone);
- /** Upload binary files to the remote agent.
- * @param BundleDir Directory containing .blob files.
- * @param BundleLocator Locator string identifying the bundle (from CreateBundle). */
- bool UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator);
+ /** Cancel in-flight operations. The completion handler is still called (with Success=false). */
+ void Cancel();
- /** Execute a command on the remote machine. */
- void Execute(const char* Exe,
- const char* const* Args,
- size_t NumArgs,
- const char* WorkingDir = nullptr,
- const char* const* EnvVars = nullptr,
- size_t NumEnvVars = 0,
- bool UseWine = false);
+ const MachineInfo& GetMachineInfo() const { return m_Config.Machine; }
- /** Poll for output and results. Returns true if the agent is still running.
- * When LogOutput is true, remote stdout is logged via ZEN_INFO. */
- bool Poll(bool LogOutput = true);
-
- void CloseConnection();
- bool IsValid() const;
-
- const MachineInfo& GetMachineInfo() const { return m_MachineInfo; }
+ enum class State
+ {
+ Idle,
+ Connecting,
+ WaitAgentAttach,
+ SentFork,
+ WaitChildAttach,
+ Uploading,
+ Executing,
+ Polling,
+ Done
+ };
private:
LoggerRef Log() { return m_Log; }
- std::unique_ptr<ComputeSocket> m_Socket;
- std::unique_ptr<AgentMessageChannel> m_AgentChannel; ///< Channel 0: agent control
- std::unique_ptr<AgentMessageChannel> m_ChildChannel; ///< Channel 100: child I/O
-
- LoggerRef m_Log;
- bool m_IsValid = false;
- bool m_HasErrors = false;
- MachineInfo m_MachineInfo;
+ void DoConnect();
+ void OnConnected(const std::error_code& Ec);
+ void DoWaitAgentAttach();
+ void OnAgentResponse(AgentMessageType Type, const uint8_t* Data, size_t Size);
+ void DoSendFork();
+ void DoWaitChildAttach();
+ void OnChildAttachResponse(AgentMessageType Type, const uint8_t* Data, size_t Size);
+ void DoUploadNext();
+ void OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size);
+ void DoExecute();
+ void DoPoll();
+ void OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size);
+ void Finish(bool Success, int32_t ExitCode = -1);
+
+ asio::io_context& m_IoContext;
+ LoggerRef m_Log;
+ State m_State = State::Idle;
+ bool m_Cancelled = false;
+
+ AsyncAgentConfig m_Config;
+ AsyncAgentCompletionHandler m_OnDone;
+ size_t m_CurrentBundleIndex = 0;
+
+ std::unique_ptr<AsyncTcpComputeTransport> m_TcpTransport;
+ std::shared_ptr<AsyncComputeSocket> m_Socket;
+ std::unique_ptr<AsyncAgentMessageChannel> m_AgentChannel;
+ std::unique_ptr<AsyncAgentMessageChannel> m_ChildChannel;
};
} // namespace zen::horde
diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp
index 998134a96..bef1bdda8 100644
--- a/src/zenhorde/hordeagentmessage.cpp
+++ b/src/zenhorde/hordeagentmessage.cpp
@@ -4,337 +4,496 @@
#include <zencore/intmath.h>
-#include <cassert>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <zencore/except_fmt.h>
+#include <zencore/logging.h>
+
#include <cstring>
+#include <limits>
namespace zen::horde {
-AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel))
-{
-}
-
-AgentMessageChannel::~AgentMessageChannel() = default;
+// --- AsyncAgentMessageChannel ---
-void
-AgentMessageChannel::Close()
+AsyncAgentMessageChannel::AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext)
+: m_Socket(std::move(Socket))
+, m_ChannelId(ChannelId)
+, m_IoContext(IoContext)
+, m_TimeoutTimer(std::make_unique<asio::steady_timer>(m_Socket->GetStrand()))
{
- CreateMessage(AgentMessageType::None, 0);
- FlushMessage();
}
-void
-AgentMessageChannel::Ping()
+AsyncAgentMessageChannel::~AsyncAgentMessageChannel()
{
- CreateMessage(AgentMessageType::Ping, 0);
- FlushMessage();
+ if (m_TimeoutTimer)
+ {
+ m_TimeoutTimer->cancel();
+ }
}
-void
-AgentMessageChannel::Fork(int ChannelId, int BufferSize)
+// --- Message building helpers ---
+
+std::vector<uint8_t>
+AsyncAgentMessageChannel::BeginMessage(AgentMessageType Type, size_t ReservePayload)
{
- CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int));
- WriteInt32(ChannelId);
- WriteInt32(BufferSize);
- FlushMessage();
+ std::vector<uint8_t> Buf;
+ Buf.reserve(MessageHeaderLength + ReservePayload);
+ Buf.push_back(static_cast<uint8_t>(Type));
+ Buf.resize(MessageHeaderLength); // 1 byte type + 4 bytes length placeholder
+ return Buf;
}
void
-AgentMessageChannel::Attach()
+AsyncAgentMessageChannel::FinalizeAndSend(std::vector<uint8_t> Msg)
{
- CreateMessage(AgentMessageType::Attach, 0);
- FlushMessage();
+ const uint32_t PayloadSize = static_cast<uint32_t>(Msg.size() - MessageHeaderLength);
+ memcpy(&Msg[1], &PayloadSize, sizeof(uint32_t));
+ m_Socket->AsyncSendFrame(m_ChannelId, std::move(Msg));
}
void
-AgentMessageChannel::UploadFiles(const char* Path, const char* Locator)
+AsyncAgentMessageChannel::WriteInt32(std::vector<uint8_t>& Buf, int Value)
{
- CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20);
- WriteString(Path);
- WriteString(Locator);
- FlushMessage();
+ const uint8_t* Ptr = reinterpret_cast<const uint8_t*>(&Value);
+ Buf.insert(Buf.end(), Ptr, Ptr + sizeof(int));
}
-void
-AgentMessageChannel::Execute(const char* Exe,
- const char* const* Args,
- size_t NumArgs,
- const char* WorkingDir,
- const char* const* EnvVars,
- size_t NumEnvVars,
- ExecuteProcessFlags Flags)
+int
+AsyncAgentMessageChannel::ReadInt32(ReadCursor& C)
{
- size_t RequiredSize = 50 + strlen(Exe);
- for (size_t i = 0; i < NumArgs; ++i)
- {
- RequiredSize += strlen(Args[i]) + 10;
- }
- if (WorkingDir)
- {
- RequiredSize += strlen(WorkingDir) + 10;
- }
- for (size_t i = 0; i < NumEnvVars; ++i)
+ if (!C.CheckAvailable(sizeof(int32_t)))
{
- RequiredSize += strlen(EnvVars[i]) + 20;
+ return 0;
}
+ int32_t Value;
+ memcpy(&Value, C.Pos, sizeof(int32_t));
+ C.Pos += sizeof(int32_t);
+ return Value;
+}
- CreateMessage(AgentMessageType::ExecuteV2, RequiredSize);
- WriteString(Exe);
+void
+AsyncAgentMessageChannel::WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length)
+{
+ Buf.insert(Buf.end(), Data, Data + Length);
+}
- WriteUnsignedVarInt(NumArgs);
- for (size_t i = 0; i < NumArgs; ++i)
+const uint8_t*
+AsyncAgentMessageChannel::ReadFixedLengthBytes(ReadCursor& C, size_t Length)
+{
+ if (!C.CheckAvailable(Length))
{
- WriteString(Args[i]);
+ return nullptr;
}
+ const uint8_t* Data = C.Pos;
+ C.Pos += Length;
+ return Data;
+}
- WriteOptionalString(WorkingDir);
-
- // ExecuteV2 protocol requires env vars as separate key/value pairs.
- // Callers pass "KEY=VALUE" strings; we split on the first '=' here.
- WriteUnsignedVarInt(NumEnvVars);
- for (size_t i = 0; i < NumEnvVars; ++i)
+size_t
+AsyncAgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
+{
+ if (Value == 0)
{
- const char* Eq = strchr(EnvVars[i], '=');
- assert(Eq != nullptr);
-
- WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i]));
- if (*(Eq + 1) == '\0')
- {
- WriteOptionalString(nullptr);
- }
- else
- {
- WriteOptionalString(Eq + 1);
- }
+ return 1;
}
-
- WriteInt32(static_cast<int>(Flags));
- FlushMessage();
+ return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
}
void
-AgentMessageChannel::Blob(const uint8_t* Data, size_t Length)
+AsyncAgentMessageChannel::WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value)
{
- // Blob responses are chunked to fit within the compute buffer's chunk size.
- // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields).
- const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength;
- for (size_t ChunkOffset = 0; ChunkOffset < Length;)
- {
- const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize);
-
- CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128);
- WriteInt32(static_cast<int>(ChunkOffset));
- WriteInt32(static_cast<int>(Length));
- WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength);
- FlushMessage();
+ const size_t ByteCount = MeasureUnsignedVarInt(Value);
+ const size_t StartPos = Buf.size();
+ Buf.resize(StartPos + ByteCount);
- ChunkOffset += ChunkLength;
+ uint8_t* Output = Buf.data() + StartPos;
+ for (size_t i = 1; i < ByteCount; ++i)
+ {
+ Output[ByteCount - i] = static_cast<uint8_t>(Value);
+ Value >>= 8;
}
+ Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
}
-AgentMessageType
-AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut)
+size_t
+AsyncAgentMessageChannel::ReadUnsignedVarInt(ReadCursor& C)
{
- // Deferred advance: the previous response's buffer is only released when the next
- // ReadResponse is called. This allows callers to read response data between calls
- // without copying, since the pointer comes directly from the ring buffer.
- if (m_ResponseData)
+ // Need at least the leading byte to determine the encoded length.
+ if (!C.CheckAvailable(1))
{
- m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength);
- m_ResponseData = nullptr;
- m_ResponseLength = 0;
+ return 0;
}
- const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut);
- if (!Header)
+ const uint8_t FirstByte = C.Pos[0];
+ const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24;
+
+ // The encoded length implied by the leading 0xFF-run may be 1..9 bytes; ensure the remaining bytes are in-bounds.
+ if (!C.CheckAvailable(NumBytes))
{
- return AgentMessageType::None;
+ return 0;
}
- uint32_t Length;
- memcpy(&Length, Header + 1, sizeof(uint32_t));
-
- Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut);
- if (!Header)
+ size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
+ for (size_t i = 1; i < NumBytes; ++i)
{
- return AgentMessageType::None;
+ Value <<= 8;
+ Value |= C.Pos[i];
}
- m_ResponseType = static_cast<AgentMessageType>(Header[0]);
- m_ResponseData = Header + MessageHeaderLength;
- m_ResponseLength = Length;
-
- return m_ResponseType;
+ C.Pos += NumBytes;
+ return Value;
}
void
-AgentMessageChannel::ReadException(ExceptionInfo& Ex)
+AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, const char* Text)
{
- assert(m_ResponseType == AgentMessageType::Exception);
- const uint8_t* Pos = m_ResponseData;
- Ex.Message = ReadString(&Pos);
- Ex.Description = ReadString(&Pos);
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Buf, Length);
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length);
}
-int
-AgentMessageChannel::ReadExecuteResult()
+void
+AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, std::string_view Text)
{
- assert(m_ResponseType == AgentMessageType::ExecuteResult);
- const uint8_t* Pos = m_ResponseData;
- return ReadInt32(&Pos);
+ WriteUnsignedVarInt(Buf, Text.size());
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
}
-void
-AgentMessageChannel::ReadBlobRequest(BlobRequest& Req)
+std::string_view
+AsyncAgentMessageChannel::ReadString(ReadCursor& C)
{
- assert(m_ResponseType == AgentMessageType::ReadBlob);
- const uint8_t* Pos = m_ResponseData;
- Req.Locator = ReadString(&Pos);
- Req.Offset = ReadUnsignedVarInt(&Pos);
- Req.Length = ReadUnsignedVarInt(&Pos);
+ const size_t Length = ReadUnsignedVarInt(C);
+ const uint8_t* Start = ReadFixedLengthBytes(C, Length);
+ if (C.ParseError || !Start)
+ {
+ return {};
+ }
+ return std::string_view(reinterpret_cast<const char*>(Start), Length);
}
void
-AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength)
+AsyncAgentMessageChannel::WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text)
{
- m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength);
- m_RequestData[0] = static_cast<uint8_t>(Type);
- m_MaxRequestSize = MaxLength;
- m_RequestSize = 0;
+ if (!Text)
+ {
+ WriteUnsignedVarInt(Buf, 0);
+ }
+ else
+ {
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Buf, Length + 1);
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length);
+ }
}
+// --- Send methods ---
+
void
-AgentMessageChannel::FlushMessage()
+AsyncAgentMessageChannel::Close()
{
- const uint32_t Size = static_cast<uint32_t>(m_RequestSize);
- memcpy(&m_RequestData[1], &Size, sizeof(uint32_t));
- m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize);
- m_RequestSize = 0;
- m_MaxRequestSize = 0;
- m_RequestData = nullptr;
+ auto Msg = BeginMessage(AgentMessageType::None, 0);
+ FinalizeAndSend(std::move(Msg));
}
void
-AgentMessageChannel::WriteInt32(int Value)
+AsyncAgentMessageChannel::Ping()
{
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int));
+ auto Msg = BeginMessage(AgentMessageType::Ping, 0);
+ FinalizeAndSend(std::move(Msg));
}
-int
-AgentMessageChannel::ReadInt32(const uint8_t** Pos)
+void
+AsyncAgentMessageChannel::Fork(int ChannelId, int BufferSize)
{
- int Value;
- memcpy(&Value, *Pos, sizeof(int));
- *Pos += sizeof(int);
- return Value;
+ auto Msg = BeginMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int));
+ WriteInt32(Msg, ChannelId);
+ WriteInt32(Msg, BufferSize);
+ FinalizeAndSend(std::move(Msg));
}
void
-AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length)
+AsyncAgentMessageChannel::Attach()
{
- assert(m_RequestSize + Length <= m_MaxRequestSize);
- memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length);
- m_RequestSize += Length;
+ auto Msg = BeginMessage(AgentMessageType::Attach, 0);
+ FinalizeAndSend(std::move(Msg));
}
-const uint8_t*
-AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length)
+void
+AsyncAgentMessageChannel::UploadFiles(const char* Path, const char* Locator)
{
- const uint8_t* Data = *Pos;
- *Pos += Length;
- return Data;
+ auto Msg = BeginMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20);
+ WriteString(Msg, Path);
+ WriteString(Msg, Locator);
+ FinalizeAndSend(std::move(Msg));
}
-size_t
-AgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
+void
+AsyncAgentMessageChannel::Execute(const char* Exe,
+ const char* const* Args,
+ size_t NumArgs,
+ const char* WorkingDir,
+ const char* const* EnvVars,
+ size_t NumEnvVars,
+ ExecuteProcessFlags Flags)
{
- if (Value == 0)
+ size_t ReserveSize = 50 + strlen(Exe);
+ for (size_t i = 0; i < NumArgs; ++i)
{
- return 1;
+ ReserveSize += strlen(Args[i]) + 10;
+ }
+ if (WorkingDir)
+ {
+ ReserveSize += strlen(WorkingDir) + 10;
+ }
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ ReserveSize += strlen(EnvVars[i]) + 20;
}
- return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
-}
-void
-AgentMessageChannel::WriteUnsignedVarInt(size_t Value)
-{
- const size_t ByteCount = MeasureUnsignedVarInt(Value);
- assert(m_RequestSize + ByteCount <= m_MaxRequestSize);
+ auto Msg = BeginMessage(AgentMessageType::ExecuteV2, ReserveSize);
+ WriteString(Msg, Exe);
- uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize;
- for (size_t i = 1; i < ByteCount; ++i)
+ WriteUnsignedVarInt(Msg, NumArgs);
+ for (size_t i = 0; i < NumArgs; ++i)
{
- Output[ByteCount - i] = static_cast<uint8_t>(Value);
- Value >>= 8;
+ WriteString(Msg, Args[i]);
}
- Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
- m_RequestSize += ByteCount;
+ WriteOptionalString(Msg, WorkingDir);
+
+ WriteUnsignedVarInt(Msg, NumEnvVars);
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ const char* Eq = strchr(EnvVars[i], '=');
+ if (Eq == nullptr)
+ {
+ // assert() would be compiled out in release and leave *(Eq+1) as UB -
+ // refuse to build the message for a malformed KEY=VALUE string instead.
+ throw zen::runtime_error("horde agent env var at index {} missing '=' separator", i);
+ }
+
+ WriteString(Msg, std::string_view(EnvVars[i], Eq - EnvVars[i]));
+ if (*(Eq + 1) == '\0')
+ {
+ WriteOptionalString(Msg, nullptr);
+ }
+ else
+ {
+ WriteOptionalString(Msg, Eq + 1);
+ }
+ }
+
+ WriteInt32(Msg, static_cast<int>(Flags));
+ FinalizeAndSend(std::move(Msg));
}
-size_t
-AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos)
+void
+AsyncAgentMessageChannel::Blob(const uint8_t* Data, size_t Length)
{
- const uint8_t* Data = *Pos;
- const uint8_t FirstByte = Data[0];
- const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24;
+ static constexpr size_t MaxBlobChunkSize = 512 * 1024;
- size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
- for (size_t i = 1; i < NumBytes; ++i)
+ // The Horde ReadBlobResponse wire format encodes both the chunk Offset and the total
+ // Length as int32. Lengths of 2 GiB or more would wrap to negative and confuse the
+ // remote parser. Refuse the send rather than produce a protocol violation.
+ if (Length > static_cast<size_t>(std::numeric_limits<int32_t>::max()))
{
- Value <<= 8;
- Value |= Data[i];
+ throw zen::runtime_error("horde ReadBlobResponse length {} exceeds int32 wire limit", Length);
}
- *Pos += NumBytes;
- return Value;
+ for (size_t ChunkOffset = 0; ChunkOffset < Length;)
+ {
+ const size_t ChunkLength = std::min(Length - ChunkOffset, MaxBlobChunkSize);
+
+ auto Msg = BeginMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128);
+ WriteInt32(Msg, static_cast<int32_t>(ChunkOffset));
+ WriteInt32(Msg, static_cast<int32_t>(Length));
+ WriteFixedLengthBytes(Msg, Data + ChunkOffset, ChunkLength);
+ FinalizeAndSend(std::move(Msg));
+
+ ChunkOffset += ChunkLength;
+ }
}
-size_t
-AgentMessageChannel::MeasureString(const char* Text) const
+// --- Async response reading ---
+
+void
+AsyncAgentMessageChannel::AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler)
{
- const size_t Length = strlen(Text);
- return MeasureUnsignedVarInt(Length) + Length;
+ // Serialize all access to m_IncomingFrames / m_PendingHandler / m_TimeoutTimer onto
+ // the socket's strand; OnFrame/OnDetach also run on that strand. Without this, the
+ // timer wait completion would run on a bare io_context thread (3 concurrent run()
+ // loops in the provisioner) and race with OnFrame on m_PendingHandler.
+ asio::dispatch(m_Socket->GetStrand(), [this, TimeoutMs, Handler = std::move(Handler)]() mutable {
+ if (!m_IncomingFrames.empty())
+ {
+ std::vector<uint8_t> Frame = std::move(m_IncomingFrames.front());
+ m_IncomingFrames.pop_front();
+
+ if (Frame.size() >= MessageHeaderLength)
+ {
+ AgentMessageType Type = static_cast<AgentMessageType>(Frame[0]);
+ const uint8_t* Data = Frame.data() + MessageHeaderLength;
+ size_t Size = Frame.size() - MessageHeaderLength;
+ asio::post(m_IoContext, [Handler = std::move(Handler), Type, Frame = std::move(Frame), Data, Size]() mutable {
+ // The Frame is captured to keep Data pointer valid
+ Handler(Type, Data, Size);
+ });
+ }
+ else
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); });
+ }
+ return;
+ }
+
+ if (m_Detached)
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); });
+ return;
+ }
+
+ // No frames queued - store pending handler and arm timeout
+ m_PendingHandler = std::move(Handler);
+
+ if (TimeoutMs >= 0)
+ {
+ m_TimeoutTimer->expires_after(std::chrono::milliseconds(TimeoutMs));
+ m_TimeoutTimer->async_wait(asio::bind_executor(m_Socket->GetStrand(), [this](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ return; // Cancelled - frame arrived before timeout
+ }
+
+ // Already on the strand: safe to mutate m_PendingHandler.
+ if (m_PendingHandler)
+ {
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
+ }));
+ }
+ });
}
void
-AgentMessageChannel::WriteString(const char* Text)
+AsyncAgentMessageChannel::OnFrame(std::vector<uint8_t> Data)
{
- const size_t Length = strlen(Text);
- WriteUnsignedVarInt(Length);
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+ if (m_PendingHandler)
+ {
+ // Cancel the timeout timer
+ m_TimeoutTimer->cancel();
+
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+
+ if (Data.size() >= MessageHeaderLength)
+ {
+ AgentMessageType Type = static_cast<AgentMessageType>(Data[0]);
+ const uint8_t* Payload = Data.data() + MessageHeaderLength;
+ size_t PayloadSize = Data.size() - MessageHeaderLength;
+ Handler(Type, Payload, PayloadSize);
+ }
+ else
+ {
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
+ }
+ else
+ {
+ m_IncomingFrames.push_back(std::move(Data));
+ }
}
void
-AgentMessageChannel::WriteString(std::string_view Text)
+AsyncAgentMessageChannel::OnDetach()
{
- WriteUnsignedVarInt(Text.size());
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ m_Detached = true;
+
+ if (m_PendingHandler)
+ {
+ m_TimeoutTimer->cancel();
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
}
-std::string_view
-AgentMessageChannel::ReadString(const uint8_t** Pos)
+// --- Response parsing helpers ---
+
+bool
+AsyncAgentMessageChannel::ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex)
{
- const size_t Length = ReadUnsignedVarInt(Pos);
- const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length));
- return std::string_view(Start, Length);
+ ReadCursor C{Data, Data + Size, false};
+ Ex.Message = ReadString(C);
+ Ex.Description = ReadString(C);
+ if (C.ParseError)
+ {
+ Ex = {};
+ return false;
+ }
+ return true;
}
-void
-AgentMessageChannel::WriteOptionalString(const char* Text)
+bool
+AsyncAgentMessageChannel::ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode)
{
- // Optional strings use length+1 encoding: 0 means null/absent,
- // N>0 means a string of length N-1 follows. This matches the UE
- // FAgentMessageChannel serialization convention.
- if (!Text)
+ ReadCursor C{Data, Data + Size, false};
+ OutExitCode = ReadInt32(C);
+ return !C.ParseError;
+}
+
+static bool
+IsSafeLocator(std::string_view Locator)
+{
+ // Reject empty, overlong, path-separator-containing, parent-relative, absolute, or
+ // control-character-containing locators. The locator is used as a filename component
+ // joined with a trusted BundleDir, so the only safe characters are a restricted
+ // filename alphabet.
+ if (Locator.empty() || Locator.size() > 255)
{
- WriteUnsignedVarInt(0);
+ return false;
}
- else
+ if (Locator == "." || Locator == "..")
{
- const size_t Length = strlen(Text);
- WriteUnsignedVarInt(Length + 1);
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+ return false;
+ }
+ for (char Ch : Locator)
+ {
+ const unsigned char U = static_cast<unsigned char>(Ch);
+ if (U < 0x20 || U == 0x7F)
+ {
+ return false; // control / NUL / DEL
+ }
+ if (Ch == '/' || Ch == '\\' || Ch == ':')
+ {
+ return false; // path separators / drive letters
+ }
+ }
+ // Disallow leading/trailing dot or whitespace (Windows quirks + hidden-file dodges)
+ if (Locator.front() == '.' || Locator.front() == ' ' || Locator.back() == '.' || Locator.back() == ' ')
+ {
+ return false;
+ }
+ return true;
+}
+
+bool
+AsyncAgentMessageChannel::ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req)
+{
+ ReadCursor C{Data, Data + Size, false};
+ Req.Locator = ReadString(C);
+ Req.Offset = ReadUnsignedVarInt(C);
+ Req.Length = ReadUnsignedVarInt(C);
+ if (C.ParseError || !IsSafeLocator(Req.Locator))
+ {
+ Req = {};
+ return false;
}
+ return true;
}
} // namespace zen::horde
diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h
index 38c4375fd..fb7c5ed29 100644
--- a/src/zenhorde/hordeagentmessage.h
+++ b/src/zenhorde/hordeagentmessage.h
@@ -4,14 +4,22 @@
#include <zenbase/zenbase.h>
-#include "hordecomputechannel.h"
+#include "hordecomputesocket.h"
#include <cstddef>
#include <cstdint>
+#include <deque>
+#include <functional>
+#include <memory>
#include <string>
#include <string_view>
+#include <system_error>
#include <vector>
+namespace asio {
+class io_context;
+} // namespace asio
+
namespace zen::horde {
/** Agent message types matching the UE EAgentMessageType byte values.
@@ -55,45 +63,34 @@ struct BlobRequest
size_t Length = 0;
};
-/** Channel for sending and receiving agent messages over a ComputeChannel.
+/** Handler for async response reads. Receives the message type and a view of the payload data.
+ * The payload vector is valid until the next AsyncReadResponse call. */
+using AsyncResponseHandler = std::function<void(AgentMessageType Type, const uint8_t* Data, size_t Size)>;
+
+/** Async channel for sending and receiving agent messages over an AsyncComputeSocket.
*
- * Implements the Horde agent message protocol, matching the UE
- * FAgentMessageChannel serialization format exactly. Messages are framed as
- * [type (1B)][payload length (4B)][payload]. Strings use length-prefixed UTF-8;
- * integers use variable-length encoding.
+ * Send methods build messages into vectors and submit them via AsyncComputeSocket.
+ * Receives are delivered via the socket's FrameHandler callback and queued internally.
+ * AsyncReadResponse checks the queue and invokes the handler, with optional timeout.
*
- * The protocol has two directions:
- * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob
- * - Responses (remote -> initiator): ReadResponse returns the type, then call the
- * appropriate Read* method to parse the payload.
+ * All operations must be externally serialized (e.g. via the socket's strand).
*/
-class AgentMessageChannel
+class AsyncAgentMessageChannel
{
public:
- explicit AgentMessageChannel(Ref<ComputeChannel> Channel);
- ~AgentMessageChannel();
+ AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext);
+ ~AsyncAgentMessageChannel();
- AgentMessageChannel(const AgentMessageChannel&) = delete;
- AgentMessageChannel& operator=(const AgentMessageChannel&) = delete;
+ AsyncAgentMessageChannel(const AsyncAgentMessageChannel&) = delete;
+ AsyncAgentMessageChannel& operator=(const AsyncAgentMessageChannel&) = delete;
- // --- Requests (Initiator -> Remote) ---
+ // --- Requests (fire-and-forget sends) ---
- /** Close the channel. */
void Close();
-
- /** Send a keepalive ping. */
void Ping();
-
- /** Fork communication to a new channel with the given ID and buffer size. */
void Fork(int ChannelId, int BufferSize);
-
- /** Send an attach request (used during channel setup handshake). */
void Attach();
-
- /** Request the remote agent to write files from the given bundle locator. */
void UploadFiles(const char* Path, const char* Locator);
-
- /** Execute a process on the remote machine. */
void Execute(const char* Exe,
const char* const* Args,
size_t NumArgs,
@@ -101,61 +98,85 @@ public:
const char* const* EnvVars,
size_t NumEnvVars,
ExecuteProcessFlags Flags = ExecuteProcessFlags::None);
-
- /** Send blob data in response to a ReadBlob request. */
void Blob(const uint8_t* Data, size_t Length);
- // --- Responses (Remote -> Initiator) ---
-
- /** Read the next response message. Returns the message type, or None on timeout.
- * After this returns, use GetResponseData()/GetResponseSize() or the typed
- * Read* methods to access the payload. */
- AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr);
+ // --- Async response reading ---
- const void* GetResponseData() const { return m_ResponseData; }
- size_t GetResponseSize() const { return m_ResponseLength; }
+ /** Read the next response. If a frame is already queued, the handler is posted immediately.
+ * Otherwise waits up to TimeoutMs for a frame to arrive. On timeout, invokes the handler
+ * with AgentMessageType::None. */
+ void AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler);
- /** Parse an Exception response payload. */
- void ReadException(ExceptionInfo& Ex);
+ /** Called by the socket's FrameHandler when a frame arrives for this channel. */
+ void OnFrame(std::vector<uint8_t> Data);
- /** Parse an ExecuteResult response payload. Returns the exit code. */
- int ReadExecuteResult();
+ /** Called by the socket's DetachHandler. */
+ void OnDetach();
- /** Parse a ReadBlob response payload into a BlobRequest. */
- void ReadBlobRequest(BlobRequest& Req);
+ /** Returns true if the channel has been detached (connection lost). */
+ bool IsDetached() const { return m_Detached; }
-private:
- static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)]
+ // --- Response parsing helpers ---
- Ref<ComputeChannel> m_Channel;
+ /** Parse an Exception message payload. Returns false on malformed/truncated input. */
+ [[nodiscard]] static bool ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex);
- uint8_t* m_RequestData = nullptr;
- size_t m_RequestSize = 0;
- size_t m_MaxRequestSize = 0;
+ /** Parse an ExecuteResult message payload. Returns false on malformed/truncated input. */
+ [[nodiscard]] static bool ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode);
- AgentMessageType m_ResponseType = AgentMessageType::None;
- const uint8_t* m_ResponseData = nullptr;
- size_t m_ResponseLength = 0;
+ /** Parse a ReadBlob message payload. Returns false on malformed/truncated input or
+ * if the Locator contains characters that would not be safe to use as a path component. */
+ [[nodiscard]] static bool ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req);
- void CreateMessage(AgentMessageType Type, size_t MaxLength);
- void FlushMessage();
+private:
+ static constexpr size_t MessageHeaderLength = 5;
+
+ // Message building helpers
+ std::vector<uint8_t> BeginMessage(AgentMessageType Type, size_t ReservePayload);
+ void FinalizeAndSend(std::vector<uint8_t> Msg);
+
+ /** Bounds-checked reader cursor. All Read* helpers set ParseError instead of reading past End. */
+ struct ReadCursor
+ {
+ const uint8_t* Pos = nullptr;
+ const uint8_t* End = nullptr;
+ bool ParseError = false;
+
+ [[nodiscard]] bool CheckAvailable(size_t N)
+ {
+ if (ParseError || static_cast<size_t>(End - Pos) < N)
+ {
+ ParseError = true;
+ return false;
+ }
+ return true;
+ }
+ };
+
+ static void WriteInt32(std::vector<uint8_t>& Buf, int Value);
+ static int ReadInt32(ReadCursor& C);
+
+ static void WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length);
+ static const uint8_t* ReadFixedLengthBytes(ReadCursor& C, size_t Length);
- void WriteInt32(int Value);
- static int ReadInt32(const uint8_t** Pos);
+ static size_t MeasureUnsignedVarInt(size_t Value);
+ static void WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value);
+ static size_t ReadUnsignedVarInt(ReadCursor& C);
- void WriteFixedLengthBytes(const uint8_t* Data, size_t Length);
- static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length);
+ static void WriteString(std::vector<uint8_t>& Buf, const char* Text);
+ static void WriteString(std::vector<uint8_t>& Buf, std::string_view Text);
+ static std::string_view ReadString(ReadCursor& C);
- static size_t MeasureUnsignedVarInt(size_t Value);
- void WriteUnsignedVarInt(size_t Value);
- static size_t ReadUnsignedVarInt(const uint8_t** Pos);
+ static void WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text);
- size_t MeasureString(const char* Text) const;
- void WriteString(const char* Text);
- void WriteString(std::string_view Text);
- static std::string_view ReadString(const uint8_t** Pos);
+ std::shared_ptr<AsyncComputeSocket> m_Socket;
+ int m_ChannelId;
+ asio::io_context& m_IoContext;
- void WriteOptionalString(const char* Text);
+ std::deque<std::vector<uint8_t>> m_IncomingFrames;
+ AsyncResponseHandler m_PendingHandler;
+ std::unique_ptr<asio::steady_timer> m_TimeoutTimer;
+ bool m_Detached = false;
};
} // namespace zen::horde
diff --git a/src/zenhorde/hordebundle.cpp b/src/zenhorde/hordebundle.cpp
index d3974bc28..8493a9456 100644
--- a/src/zenhorde/hordebundle.cpp
+++ b/src/zenhorde/hordebundle.cpp
@@ -10,6 +10,7 @@
#include <zencore/logging.h>
#include <zencore/process.h>
#include <zencore/trace.h>
+#include <zencore/uid.h>
#include <algorithm>
#include <chrono>
@@ -48,7 +49,7 @@ static constexpr uint8_t BlobType_DirectoryV1[20] = {0x11, 0xEC, 0x14, 0x07, 0x1
static constexpr size_t BlobTypeSize = 20;
-// ─── VarInt helpers (UE format) ─────────────────────────────────────────────
+// --- VarInt helpers (UE format) ---------------------------------------------
static size_t
MeasureVarInt(size_t Value)
@@ -57,7 +58,7 @@ MeasureVarInt(size_t Value)
{
return 1;
}
- return (FloorLog2(static_cast<unsigned int>(Value)) / 7) + 1;
+ return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
}
static void
@@ -76,7 +77,7 @@ WriteVarInt(std::vector<uint8_t>& Buffer, size_t Value)
Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
}
-// ─── Binary helpers ─────────────────────────────────────────────────────────
+// --- Binary helpers ---------------------------------------------------------
static void
WriteLE32(std::vector<uint8_t>& Buffer, int32_t Value)
@@ -121,7 +122,7 @@ PatchLE32(std::vector<uint8_t>& Buffer, size_t Offset, int32_t Value)
memcpy(Buffer.data() + Offset, &Value, 4);
}
-// ─── Packet builder ─────────────────────────────────────────────────────────
+// --- Packet builder ---------------------------------------------------------
// Builds a single uncompressed Horde V2 packet. Layout:
// [Signature(3) + Version(1) + PacketLength(4)] 8 bytes (header)
@@ -229,7 +230,7 @@ struct PacketBuilder
{
AlignTo4(Data);
- // ── Type table: count(int32) + count * BlobTypeSize bytes ──
+ // -- Type table: count(int32) + count * BlobTypeSize bytes --
const int32_t TypeTableOffset = static_cast<int32_t>(Data.size());
WriteLE32(Data, static_cast<int32_t>(Types.size()));
for (const uint8_t* TypeEntry : Types)
@@ -237,12 +238,12 @@ struct PacketBuilder
WriteBytes(Data, TypeEntry, BlobTypeSize);
}
- // ── Import table: count(int32) + (count+1) offsets(int32 each) + import data ──
+ // -- Import table: count(int32) + (count+1) offsets(int32 each) + import data --
const int32_t ImportTableOffset = static_cast<int32_t>(Data.size());
const int32_t ImportCount = static_cast<int32_t>(Imports.size());
WriteLE32(Data, ImportCount);
- // Reserve space for (count+1) offset entries — will be patched below
+ // Reserve space for (count+1) offset entries - will be patched below
const size_t ImportOffsetsStart = Data.size();
for (int32_t i = 0; i <= ImportCount; ++i)
{
@@ -266,7 +267,7 @@ struct PacketBuilder
// Sentinel offset (points past the last import's data)
PatchLE32(Data, ImportOffsetsStart + static_cast<size_t>(ImportCount) * 4, static_cast<int32_t>(Data.size()));
- // ── Export table: count(int32) + (count+1) offsets(int32 each) ──
+ // -- Export table: count(int32) + (count+1) offsets(int32 each) --
const int32_t ExportTableOffset = static_cast<int32_t>(Data.size());
const int32_t ExportCount = static_cast<int32_t>(ExportOffsets.size());
WriteLE32(Data, ExportCount);
@@ -278,7 +279,7 @@ struct PacketBuilder
// Sentinel: points to the start of the type table (end of export data region)
WriteLE32(Data, TypeTableOffset);
- // ── Patch header ──
+ // -- Patch header --
// PacketLength = total packet size including the 8-byte header
const int32_t PacketLength = static_cast<int32_t>(Data.size());
PatchLE32(Data, 4, PacketLength);
@@ -290,7 +291,7 @@ struct PacketBuilder
}
};
-// ─── Encoded packet wrapper ─────────────────────────────────────────────────
+// --- Encoded packet wrapper -------------------------------------------------
// Wraps an uncompressed packet with the encoded header:
// [Signature(3) + Version(1) + HeaderLength(4)] 8 bytes
@@ -327,24 +328,22 @@ EncodePacket(std::vector<uint8_t> UncompressedPacket)
return Encoded;
}
-// ─── Bundle blob name generation ────────────────────────────────────────────
+// --- Bundle blob name generation --------------------------------------------
static std::string
GenerateBlobName()
{
- static std::atomic<uint32_t> s_Counter{0};
-
- const int Pid = GetCurrentProcessId();
-
- auto Now = std::chrono::steady_clock::now().time_since_epoch();
- auto Ms = std::chrono::duration_cast<std::chrono::milliseconds>(Now).count();
-
- ExtendableStringBuilder<64> Name;
- Name << Pid << "_" << Ms << "_" << s_Counter.fetch_add(1);
- return std::string(Name.ToView());
+ // Oid is a 12-byte identifier built from a timestamp, a monotonic serial number
+ // initialised from std::random_device, and a per-process run id also drawn from
+ // std::random_device. The 24-hex-char rendering gives ~80 bits of effective
+ // name-prediction entropy, so a local attacker cannot race-create the blob
+ // path before we open it. Previously the name was pid+ms+counter, which two
+ // zenserver processes with the same PID could collide on and which was
+ // entirely predictable.
+ return zen::Oid::NewOid().ToString();
}
-// ─── File info for bundling ─────────────────────────────────────────────────
+// --- File info for bundling -------------------------------------------------
struct FileInfo
{
@@ -357,7 +356,7 @@ struct FileInfo
IoHash RootExportHash; // IoHash of the root export for this file
};
-// ─── CreateBundle implementation ────────────────────────────────────────────
+// --- CreateBundle implementation --------------------------------------------
bool
BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult)
@@ -534,7 +533,7 @@ BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::fil
FileInfo& Info = ValidFiles[i];
DirImports.push_back(Info.DirectoryExportImportIndex);
- // IoHash of target (20 bytes) — import is consumed sequentially from the
+ // IoHash of target (20 bytes) - import is consumed sequentially from the
// export's import list by ReadBlobRef, not encoded in the payload
WriteBytes(DirPayload, Info.RootExportHash.Hash, sizeof(IoHash));
// name (string)
@@ -557,8 +556,16 @@ BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::fil
std::vector<uint8_t> UncompressedPacket = Packet.Finish();
std::vector<uint8_t> EncodedPacket = EncodePacket(std::move(UncompressedPacket));
- // Write .blob file
+ // Write .blob file. Refuse to proceed if a file with this name already exists -
+ // the Oid-based BlobName should make collisions astronomically unlikely, so an
+ // existing file implies either an extraordinary collision or an attacker having
+ // pre-seeded the path; either way, we do not want to overwrite it.
const std::filesystem::path BlobFilePath = OutputDir / (BlobName + ".blob");
+ if (std::filesystem::exists(BlobFilePath, Ec))
+ {
+ ZEN_ERROR("blob file already exists at {} - refusing to overwrite", BlobFilePath.string());
+ return false;
+ }
{
BasicFile BlobFile(BlobFilePath, BasicFile::Mode::kTruncate, Ec);
if (Ec)
@@ -574,8 +581,10 @@ BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::fil
Locator << BlobName << "#pkt=0," << uint64_t(EncodedPacket.size()) << "&exp=" << DirExportIndex;
const std::string LocatorStr(Locator.ToView());
- // Write .ref file (use first file's name as the ref base)
- const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + ".Bundle.ref");
+ // Write .ref file. Include the Oid-based BlobName so that two concurrent
+ // CreateBundle() calls into the same OutputDir that happen to share the first
+ // filename don't clobber each other's ref file.
+ const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + "." + BlobName + ".Bundle.ref");
{
BasicFile RefFile(RefFilePath, BasicFile::Mode::kTruncate, Ec);
if (Ec)
diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp
index 0eefc57c6..762edce06 100644
--- a/src/zenhorde/hordeclient.cpp
+++ b/src/zenhorde/hordeclient.cpp
@@ -4,6 +4,7 @@
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
#include <zencore/memoryview.h>
+#include <zencore/string.h>
#include <zencore/trace.h>
#include <zenhorde/hordeclient.h>
#include <zenhttp/httpclient.h>
@@ -14,7 +15,7 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen::horde {
-HordeClient::HordeClient(const HordeConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("horde.client"))
+HordeClient::HordeClient(HordeConfig Config) : m_Config(std::move(Config)), m_Log("horde.client")
{
}
@@ -32,16 +33,24 @@ HordeClient::Initialize()
Settings.RetryCount = 1;
Settings.ExpectedErrorCodes = {HttpResponseCode::ServiceUnavailable, HttpResponseCode::TooManyRequests};
- if (!m_Config.AuthToken.empty())
+ if (m_Config.AccessTokenProvider)
{
+ Settings.AccessTokenProvider = m_Config.AccessTokenProvider;
+ }
+ else if (!m_Config.AuthToken.empty())
+ {
+ // Static tokens have no wire-provided expiry. Synthesising \"now + 24h\" is wrong
+ // in both directions: if the real token expires before 24h we keep sending it after
+ // it dies; if it's long-lived we force unnecessary re-auth churn every day. Use the
+ // never-expires sentinel, matching zenhttp's CreateFromStaticToken.
Settings.AccessTokenProvider = [token = m_Config.AuthToken]() -> HttpClientAccessToken {
- return HttpClientAccessToken(token, HttpClientAccessToken::Clock::now() + std::chrono::hours{24});
+ return HttpClientAccessToken(token, HttpClientAccessToken::TimePoint::max());
};
}
m_Http = std::make_unique<zen::HttpClient>(m_Config.ServerUrl, Settings);
- if (!m_Config.AuthToken.empty())
+ if (Settings.AccessTokenProvider)
{
if (!m_Http->Authenticate())
{
@@ -63,24 +72,21 @@ HordeClient::BuildRequestBody() const
Requirements["pool"] = m_Config.Pool;
}
- std::string Condition;
-#if ZEN_PLATFORM_WINDOWS
ExtendableStringBuilder<256> CondBuf;
+#if ZEN_PLATFORM_WINDOWS
CondBuf << "(OSFamily == 'Windows' || WineEnabled == '" << (m_Config.AllowWine ? "true" : "false") << "')";
- Condition = std::string(CondBuf);
#elif ZEN_PLATFORM_MAC
- Condition = "OSFamily == 'MacOS'";
+ CondBuf << "OSFamily == 'MacOS'";
#else
- Condition = "OSFamily == 'Linux'";
+ CondBuf << "OSFamily == 'Linux'";
#endif
if (!m_Config.Condition.empty())
{
- Condition += " ";
- Condition += m_Config.Condition;
+ CondBuf << " " << m_Config.Condition;
}
- Requirements["condition"] = Condition;
+ Requirements["condition"] = std::string(CondBuf);
Requirements["exclusive"] = true;
json11::Json::object Connection;
@@ -156,39 +162,27 @@ HordeClient::ResolveCluster(const std::string& RequestBody, ClusterInfo& OutClus
return false;
}
- OutCluster.ClusterId = ClusterIdVal.string_value();
- return true;
-}
-
-bool
-HordeClient::ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize)
-{
- if (Hex.size() != OutSize * 2)
+ // A server-returned ClusterId is interpolated directly into the request URL below
+ // (api/v2/compute/<ClusterId>), so a compromised or MITM'd Horde server could
+ // otherwise inject additional path segments or query strings. Constrain to a
+ // conservative identifier alphabet.
+ const std::string& ClusterIdStr = ClusterIdVal.string_value();
+ if (ClusterIdStr.size() > 64)
{
+ ZEN_WARN("rejecting overlong clusterId ({} bytes) in cluster resolution response", ClusterIdStr.size());
return false;
}
-
- for (size_t i = 0; i < OutSize; ++i)
+ static constexpr AsciiSet ValidClusterIdCharactersSet{"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-"};
+ if (!AsciiSet::HasOnly(ClusterIdStr, ValidClusterIdCharactersSet))
{
- auto HexToByte = [](char c) -> int {
- if (c >= '0' && c <= '9')
- return c - '0';
- if (c >= 'a' && c <= 'f')
- return c - 'a' + 10;
- if (c >= 'A' && c <= 'F')
- return c - 'A' + 10;
- return -1;
- };
-
- const int Hi = HexToByte(Hex[i * 2]);
- const int Lo = HexToByte(Hex[i * 2 + 1]);
- if (Hi < 0 || Lo < 0)
- {
- return false;
- }
- Out[i] = static_cast<uint8_t>((Hi << 4) | Lo);
+ ZEN_WARN("rejecting clusterId with unsafe character in cluster resolution response");
+ return false;
}
+ OutCluster.ClusterId = ClusterIdStr;
+
+ ZEN_DEBUG("cluster resolution succeeded: clusterId='{}'", OutCluster.ClusterId);
+
return true;
}
@@ -197,8 +191,6 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C
{
ZEN_TRACE_CPU("HordeClient::RequestMachine");
- ZEN_INFO("requesting machine from Horde with cluster '{}'", ClusterId.empty() ? "default" : ClusterId.c_str());
-
ExtendableStringBuilder<128> ResourcePath;
ResourcePath << "api/v2/compute/" << (ClusterId.empty() ? "default" : ClusterId.c_str());
@@ -318,11 +310,15 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C
}
else if (Prop.starts_with("LogicalCores="))
{
- LogicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 13));
+ LogicalCores = ParseInt<uint16_t>(std::string_view(Prop).substr(13)).value_or(0);
}
else if (Prop.starts_with("PhysicalCores="))
{
- PhysicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 14));
+ PhysicalCores = ParseInt<uint16_t>(std::string_view(Prop).substr(14)).value_or(0);
+ }
+ else if (Prop.starts_with("Pool="))
+ {
+ OutMachine.Pool = Prop.substr(5);
}
}
}
@@ -367,10 +363,12 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C
OutMachine.LeaseId = LeaseIdVal.string_value();
}
- ZEN_INFO("Horde machine assigned [{}:{}] cores={} lease={}",
+ ZEN_INFO("Horde machine assigned [{}:{}] mode={} cores={} pool={} lease={}",
OutMachine.GetConnectionAddress(),
OutMachine.GetConnectionPort(),
+ ToString(OutMachine.Mode),
OutMachine.LogicalCores,
+ OutMachine.Pool,
OutMachine.LeaseId);
return true;
diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp
deleted file mode 100644
index 0d032b5d5..000000000
--- a/src/zenhorde/hordecomputebuffer.cpp
+++ /dev/null
@@ -1,454 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include "hordecomputebuffer.h"
-
-#include <algorithm>
-#include <cassert>
-#include <chrono>
-#include <condition_variable>
-#include <cstring>
-
-namespace zen::horde {
-
-// Simplified ring buffer implementation for in-process use only.
-// Uses a single contiguous buffer with write/read cursors and
-// mutex+condvar for synchronization. This is simpler than the UE version
-// which uses lock-free atomics and shared memory, but sufficient for our
-// use case where we're the initiator side of the compute protocol.
-
-struct ComputeBuffer::Detail : TRefCounted<Detail>
-{
- std::vector<uint8_t> Data;
- size_t NumChunks = 0;
- size_t ChunkLength = 0;
-
- // Current write state
- size_t WriteChunkIdx = 0;
- size_t WriteOffset = 0;
- bool WriteComplete = false;
-
- // Current read state
- size_t ReadChunkIdx = 0;
- size_t ReadOffset = 0;
- bool Detached = false;
-
- // Per-chunk written length
- std::vector<size_t> ChunkWrittenLength;
- std::vector<bool> ChunkFinished; // Writer moved to next chunk
-
- std::mutex Mutex;
- std::condition_variable ReadCV; ///< Signaled when new data is written or stream completes
- std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space
-
- bool HasWriter = false;
- bool HasReader = false;
-
- uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; }
- const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; }
-};
-
-// ComputeBuffer
-
-ComputeBuffer::ComputeBuffer()
-{
-}
-ComputeBuffer::~ComputeBuffer()
-{
-}
-
-bool
-ComputeBuffer::CreateNew(const Params& InParams)
-{
- auto* NewDetail = new Detail();
- NewDetail->NumChunks = InParams.NumChunks;
- NewDetail->ChunkLength = InParams.ChunkLength;
- NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0);
- NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0);
- NewDetail->ChunkFinished.resize(InParams.NumChunks, false);
-
- m_Detail = NewDetail;
- return true;
-}
-
-void
-ComputeBuffer::Close()
-{
- m_Detail = nullptr;
-}
-
-bool
-ComputeBuffer::IsValid() const
-{
- return static_cast<bool>(m_Detail);
-}
-
-ComputeBufferReader
-ComputeBuffer::CreateReader()
-{
- assert(m_Detail);
- m_Detail->HasReader = true;
- return ComputeBufferReader(m_Detail);
-}
-
-ComputeBufferWriter
-ComputeBuffer::CreateWriter()
-{
- assert(m_Detail);
- m_Detail->HasWriter = true;
- return ComputeBufferWriter(m_Detail);
-}
-
-// ComputeBufferReader
-
-ComputeBufferReader::ComputeBufferReader()
-{
-}
-ComputeBufferReader::~ComputeBufferReader()
-{
-}
-
-ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default;
-ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default;
-ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default;
-ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default;
-
-ComputeBufferReader::ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
-{
-}
-
-void
-ComputeBufferReader::Close()
-{
- m_Detail = nullptr;
-}
-
-void
-ComputeBufferReader::Detach()
-{
- if (m_Detail)
- {
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- m_Detail->Detached = true;
- m_Detail->ReadCV.notify_all();
- }
-}
-
-bool
-ComputeBufferReader::IsValid() const
-{
- return static_cast<bool>(m_Detail);
-}
-
-bool
-ComputeBufferReader::IsComplete() const
-{
- if (!m_Detail)
- {
- return true;
- }
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- if (m_Detail->Detached)
- {
- return true;
- }
- return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx &&
- m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx];
-}
-
-void
-ComputeBufferReader::AdvanceReadPosition(size_t Size)
-{
- if (!m_Detail)
- {
- return;
- }
-
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
-
- m_Detail->ReadOffset += Size;
-
- // Check if we need to move to next chunk
- const size_t ReadChunk = m_Detail->ReadChunkIdx;
- if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk])
- {
- const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks;
- m_Detail->ReadChunkIdx = NextChunk;
- m_Detail->ReadOffset = 0;
- m_Detail->WriteCV.notify_all();
- }
-
- m_Detail->ReadCV.notify_all();
-}
-
-size_t
-ComputeBufferReader::GetMaxReadSize() const
-{
- if (!m_Detail)
- {
- return 0;
- }
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- const size_t ReadChunk = m_Detail->ReadChunkIdx;
- return m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
-}
-
-const uint8_t*
-ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut)
-{
- if (!m_Detail)
- {
- return nullptr;
- }
-
- std::unique_lock<std::mutex> Lock(m_Detail->Mutex);
-
- auto Predicate = [&]() -> bool {
- if (m_Detail->Detached)
- {
- return true;
- }
-
- const size_t ReadChunk = m_Detail->ReadChunkIdx;
- const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
-
- if (Available >= MinSize)
- {
- return true;
- }
-
- // If chunk is finished and we've read everything, try to move to next
- if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk])
- {
- if (m_Detail->WriteComplete)
- {
- return true; // End of stream
- }
- // Move to next chunk
- const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks;
- m_Detail->ReadChunkIdx = NextChunk;
- m_Detail->ReadOffset = 0;
- m_Detail->WriteCV.notify_all();
- return false; // Re-check with new chunk
- }
-
- if (m_Detail->WriteComplete)
- {
- return true; // End of stream
- }
-
- return false;
- };
-
- if (TimeoutMs < 0)
- {
- m_Detail->ReadCV.wait(Lock, Predicate);
- }
- else
- {
- if (!m_Detail->ReadCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate))
- {
- if (OutTimedOut)
- {
- *OutTimedOut = true;
- }
- return nullptr;
- }
- }
-
- if (m_Detail->Detached)
- {
- return nullptr;
- }
-
- const size_t ReadChunk = m_Detail->ReadChunkIdx;
- const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
-
- if (Available < MinSize)
- {
- return nullptr; // End of stream
- }
-
- return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset;
-}
-
-size_t
-ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut)
-{
- const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut);
- if (!Data)
- {
- return 0;
- }
-
- const size_t Available = GetMaxReadSize();
- const size_t ToCopy = std::min(Available, MaxSize);
- memcpy(Buffer, Data, ToCopy);
- AdvanceReadPosition(ToCopy);
- return ToCopy;
-}
-
-// ComputeBufferWriter
-
-ComputeBufferWriter::ComputeBufferWriter() = default;
-ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default;
-ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default;
-ComputeBufferWriter::~ComputeBufferWriter() = default;
-ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default;
-ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default;
-
-ComputeBufferWriter::ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
-{
-}
-
-void
-ComputeBufferWriter::Close()
-{
- if (m_Detail)
- {
- {
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- if (!m_Detail->WriteComplete)
- {
- m_Detail->WriteComplete = true;
- m_Detail->ReadCV.notify_all();
- }
- }
- m_Detail = nullptr;
- }
-}
-
-bool
-ComputeBufferWriter::IsValid() const
-{
- return static_cast<bool>(m_Detail);
-}
-
-void
-ComputeBufferWriter::MarkComplete()
-{
- if (m_Detail)
- {
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- m_Detail->WriteComplete = true;
- m_Detail->ReadCV.notify_all();
- }
-}
-
-void
-ComputeBufferWriter::AdvanceWritePosition(size_t Size)
-{
- if (!m_Detail || Size == 0)
- {
- return;
- }
-
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- const size_t WriteChunk = m_Detail->WriteChunkIdx;
- m_Detail->ChunkWrittenLength[WriteChunk] += Size;
- m_Detail->WriteOffset += Size;
- m_Detail->ReadCV.notify_all();
-}
-
-size_t
-ComputeBufferWriter::GetMaxWriteSize() const
-{
- if (!m_Detail)
- {
- return 0;
- }
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- const size_t WriteChunk = m_Detail->WriteChunkIdx;
- return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];
-}
-
-size_t
-ComputeBufferWriter::GetChunkMaxLength() const
-{
- if (!m_Detail)
- {
- return 0;
- }
- return m_Detail->ChunkLength;
-}
-
-size_t
-ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs)
-{
- uint8_t* Dest = WaitToWrite(1, TimeoutMs);
- if (!Dest)
- {
- return 0;
- }
-
- const size_t Available = GetMaxWriteSize();
- const size_t ToCopy = std::min(Available, MaxSize);
- memcpy(Dest, Buffer, ToCopy);
- AdvanceWritePosition(ToCopy);
- return ToCopy;
-}
-
-uint8_t*
-ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs)
-{
- if (!m_Detail)
- {
- return nullptr;
- }
-
- std::unique_lock<std::mutex> Lock(m_Detail->Mutex);
-
- if (m_Detail->WriteComplete)
- {
- return nullptr;
- }
-
- const size_t WriteChunk = m_Detail->WriteChunkIdx;
- const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];
-
- // If current chunk has enough space, return pointer
- if (Available >= MinSize)
- {
- return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk];
- }
-
- // Current chunk is full - mark it as finished and move to next.
- // The writer cannot advance until the reader has fully consumed the next chunk,
- // preventing the writer from overwriting data the reader hasn't processed yet.
- m_Detail->ChunkFinished[WriteChunk] = true;
- m_Detail->ReadCV.notify_all();
-
- const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks;
-
- // Wait until reader has consumed the next chunk
- auto Predicate = [&]() -> bool {
- // Check if read has moved past this chunk
- return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached;
- };
-
- if (TimeoutMs < 0)
- {
- m_Detail->WriteCV.wait(Lock, Predicate);
- }
- else
- {
- if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate))
- {
- return nullptr;
- }
- }
-
- if (m_Detail->Detached)
- {
- return nullptr;
- }
-
- // Reset next chunk
- m_Detail->ChunkWrittenLength[NextChunk] = 0;
- m_Detail->ChunkFinished[NextChunk] = false;
- m_Detail->WriteChunkIdx = NextChunk;
- m_Detail->WriteOffset = 0;
-
- return m_Detail->ChunkPtr(NextChunk);
-}
-
-} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputebuffer.h b/src/zenhorde/hordecomputebuffer.h
deleted file mode 100644
index 64ef91b7a..000000000
--- a/src/zenhorde/hordecomputebuffer.h
+++ /dev/null
@@ -1,136 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zenbase/refcount.h>
-
-#include <cstddef>
-#include <cstdint>
-#include <mutex>
-#include <vector>
-
-namespace zen::horde {
-
-class ComputeBufferReader;
-class ComputeBufferWriter;
-
-/** Simplified in-process ring buffer for the Horde compute protocol.
- *
- * Unlike the UE FComputeBuffer which supports shared-memory and memory-mapped files,
- * this implementation uses plain heap-allocated memory since we only need in-process
- * communication between channel and transport threads. The buffer is divided into
- * fixed-size chunks; readers and writers block when no space is available.
- */
-class ComputeBuffer
-{
-public:
- struct Params
- {
- size_t NumChunks = 2;
- size_t ChunkLength = 512 * 1024;
- };
-
- ComputeBuffer();
- ~ComputeBuffer();
-
- ComputeBuffer(const ComputeBuffer&) = delete;
- ComputeBuffer& operator=(const ComputeBuffer&) = delete;
-
- bool CreateNew(const Params& InParams);
- void Close();
-
- bool IsValid() const;
-
- ComputeBufferReader CreateReader();
- ComputeBufferWriter CreateWriter();
-
-private:
- struct Detail;
- Ref<Detail> m_Detail;
-
- friend class ComputeBufferReader;
- friend class ComputeBufferWriter;
-};
-
-/** Read endpoint for a ComputeBuffer.
- *
- * Provides blocking reads from the ring buffer. WaitToRead() returns a pointer
- * directly into the buffer memory (zero-copy); the caller must call
- * AdvanceReadPosition() after consuming the data.
- */
-class ComputeBufferReader
-{
-public:
- ComputeBufferReader();
- ComputeBufferReader(const ComputeBufferReader&);
- ComputeBufferReader(ComputeBufferReader&&) noexcept;
- ~ComputeBufferReader();
-
- ComputeBufferReader& operator=(const ComputeBufferReader&);
- ComputeBufferReader& operator=(ComputeBufferReader&&) noexcept;
-
- void Close();
- void Detach();
- bool IsValid() const;
- bool IsComplete() const;
-
- void AdvanceReadPosition(size_t Size);
- size_t GetMaxReadSize() const;
-
- /** Copy up to MaxSize bytes from the buffer into Buffer. Blocks until data is available. */
- size_t Read(void* Buffer, size_t MaxSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr);
-
- /** Wait until at least MinSize bytes are available and return a direct pointer.
- * Returns nullptr on timeout or if the writer has completed. */
- const uint8_t* WaitToRead(size_t MinSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr);
-
-private:
- friend class ComputeBuffer;
- explicit ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail);
-
- Ref<ComputeBuffer::Detail> m_Detail;
-};
-
-/** Write endpoint for a ComputeBuffer.
- *
- * Provides blocking writes into the ring buffer. WaitToWrite() returns a pointer
- * directly into the buffer memory (zero-copy); the caller must call
- * AdvanceWritePosition() after filling the data. Call MarkComplete() to signal
- * that no more data will be written.
- */
-class ComputeBufferWriter
-{
-public:
- ComputeBufferWriter();
- ComputeBufferWriter(const ComputeBufferWriter&);
- ComputeBufferWriter(ComputeBufferWriter&&) noexcept;
- ~ComputeBufferWriter();
-
- ComputeBufferWriter& operator=(const ComputeBufferWriter&);
- ComputeBufferWriter& operator=(ComputeBufferWriter&&) noexcept;
-
- void Close();
- bool IsValid() const;
-
- /** Signal that no more data will be written. Unblocks any waiting readers. */
- void MarkComplete();
-
- void AdvanceWritePosition(size_t Size);
- size_t GetMaxWriteSize() const;
- size_t GetChunkMaxLength() const;
-
- /** Copy up to MaxSize bytes from Buffer into the ring buffer. Blocks until space is available. */
- size_t Write(const void* Buffer, size_t MaxSize, int TimeoutMs = -1);
-
- /** Wait until at least MinSize bytes of write space are available and return a direct pointer.
- * Returns nullptr on timeout. */
- uint8_t* WaitToWrite(size_t MinSize, int TimeoutMs = -1);
-
-private:
- friend class ComputeBuffer;
- explicit ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail);
-
- Ref<ComputeBuffer::Detail> m_Detail;
-};
-
-} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputechannel.cpp b/src/zenhorde/hordecomputechannel.cpp
deleted file mode 100644
index ee2a6f327..000000000
--- a/src/zenhorde/hordecomputechannel.cpp
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include "hordecomputechannel.h"
-
-namespace zen::horde {
-
-ComputeChannel::ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter)
-: Reader(std::move(InReader))
-, Writer(std::move(InWriter))
-{
-}
-
-bool
-ComputeChannel::IsValid() const
-{
- return Reader.IsValid() && Writer.IsValid();
-}
-
-size_t
-ComputeChannel::Send(const void* Data, size_t Size, int TimeoutMs)
-{
- return Writer.Write(Data, Size, TimeoutMs);
-}
-
-size_t
-ComputeChannel::Recv(void* Data, size_t Size, int TimeoutMs)
-{
- return Reader.Read(Data, Size, TimeoutMs);
-}
-
-void
-ComputeChannel::MarkComplete()
-{
- Writer.MarkComplete();
-}
-
-} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputechannel.h b/src/zenhorde/hordecomputechannel.h
deleted file mode 100644
index c1dff20e4..000000000
--- a/src/zenhorde/hordecomputechannel.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include "hordecomputebuffer.h"
-
-namespace zen::horde {
-
-/** Bidirectional communication channel using a pair of compute buffers.
- *
- * Pairs a ComputeBufferReader (for receiving data) with a ComputeBufferWriter
- * (for sending data). Used by ComputeSocket to represent one logical channel
- * within a multiplexed connection.
- */
-class ComputeChannel : public TRefCounted<ComputeChannel>
-{
-public:
- ComputeBufferReader Reader;
- ComputeBufferWriter Writer;
-
- ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter);
-
- bool IsValid() const;
-
- size_t Send(const void* Data, size_t Size, int TimeoutMs = -1);
- size_t Recv(void* Data, size_t Size, int TimeoutMs = -1);
-
- /** Signal that no more data will be sent on this channel. */
- void MarkComplete();
-};
-
-} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp
index 6ef67760c..92a56c077 100644
--- a/src/zenhorde/hordecomputesocket.cpp
+++ b/src/zenhorde/hordecomputesocket.cpp
@@ -6,198 +6,326 @@
namespace zen::horde {
-ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport)
-: m_Log(zen::logging::Get("horde.socket"))
+AsyncComputeSocket::AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext)
+: m_Log(zen::logging::Get("horde.socket.async"))
, m_Transport(std::move(Transport))
+, m_Strand(asio::make_strand(IoContext))
+, m_PingTimer(m_Strand)
{
}
-ComputeSocket::~ComputeSocket()
+AsyncComputeSocket::~AsyncComputeSocket()
{
- // Shutdown order matters: first stop the ping thread, then unblock send threads
- // by detaching readers, then join send threads, and finally close the transport
- // to unblock the recv thread (which is blocked on RecvMessage).
- {
- std::lock_guard<std::mutex> Lock(m_PingMutex);
- m_PingShouldStop = true;
- m_PingCV.notify_all();
- }
-
- for (auto& Reader : m_Readers)
- {
- Reader.Detach();
- }
-
- for (auto& [Id, Thread] : m_SendThreads)
- {
- if (Thread.joinable())
- {
- Thread.join();
- }
- }
+ Close();
+}
- m_Transport->Close();
+void
+AsyncComputeSocket::RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach)
+{
+ m_FrameHandlers[ChannelId] = std::move(OnFrame);
+ m_DetachHandlers[ChannelId] = std::move(OnDetach);
+}
- if (m_RecvThread.joinable())
- {
- m_RecvThread.join();
- }
- if (m_PingThread.joinable())
- {
- m_PingThread.join();
- }
+void
+AsyncComputeSocket::StartRecvPump()
+{
+ StartPingTimer();
+ DoRecvHeader();
}
-Ref<ComputeChannel>
-ComputeSocket::CreateChannel(int ChannelId)
+void
+AsyncComputeSocket::DoRecvHeader()
{
- ComputeBuffer::Params Params;
+ auto Self = shared_from_this();
+ m_Transport->AsyncRead(&m_RecvHeader,
+ sizeof(FrameHeader),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted && !m_Closed)
+ {
+ ZEN_WARN("recv header error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
- ComputeBuffer RecvBuffer;
- if (!RecvBuffer.CreateNew(Params))
- {
- return {};
- }
+ if (m_Closed)
+ {
+ return;
+ }
- ComputeBuffer SendBuffer;
- if (!SendBuffer.CreateNew(Params))
- {
- return {};
- }
+ if (m_RecvHeader.Size >= 0)
+ {
+ DoRecvPayload(m_RecvHeader);
+ }
+ else if (m_RecvHeader.Size == ControlDetach)
+ {
+ if (auto It = m_DetachHandlers.find(m_RecvHeader.Channel); It != m_DetachHandlers.end() && It->second)
+ {
+ It->second();
+ }
+ DoRecvHeader();
+ }
+ else if (m_RecvHeader.Size == ControlPing)
+ {
+ DoRecvHeader();
+ }
+ else
+ {
+ ZEN_WARN("invalid frame header size: {}", m_RecvHeader.Size);
+ }
+ }));
+}
- Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter()));
+void
+AsyncComputeSocket::DoRecvPayload(FrameHeader Header)
+{
+ auto PayloadBuf = std::make_shared<std::vector<uint8_t>>(static_cast<size_t>(Header.Size));
+ auto Self = shared_from_this();
- // Attach recv buffer writer (transport recv thread writes into this)
- {
- std::lock_guard<std::mutex> Lock(m_WritersMutex);
- m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter());
- }
+ m_Transport->AsyncRead(PayloadBuf->data(),
+ PayloadBuf->size(),
+ asio::bind_executor(m_Strand, [this, Self, Header, PayloadBuf](const std::error_code& Ec, size_t /*Bytes*/) {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted && !m_Closed)
+ {
+ ZEN_WARN("recv payload error (channel={}, size={}): {}", Header.Channel, Header.Size, Ec.message());
+ HandleError();
+ }
+ return;
+ }
- // Attach send buffer reader (send thread reads from this)
- {
- ComputeBufferReader Reader = SendBuffer.CreateReader();
- m_Readers.push_back(Reader);
- m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader)));
- }
+ if (m_Closed)
+ {
+ return;
+ }
+
+ if (auto It = m_FrameHandlers.find(Header.Channel); It != m_FrameHandlers.end() && It->second)
+ {
+ It->second(std::move(*PayloadBuf));
+ }
+ else
+ {
+ ZEN_WARN("recv frame for unknown channel {}", Header.Channel);
+ }
- return Channel;
+ DoRecvHeader();
+ }));
}
void
-ComputeSocket::StartCommunication()
+AsyncComputeSocket::AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler)
{
- m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this);
- m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this);
+ auto Self = shared_from_this();
+ asio::dispatch(m_Strand, [this, Self, ChannelId, Data = std::move(Data), Handler = std::move(Handler)]() mutable {
+ if (m_Closed)
+ {
+ if (Handler)
+ {
+ Handler(asio::error::make_error_code(asio::error::operation_aborted));
+ }
+ return;
+ }
+
+ PendingWrite Write;
+ Write.Header.Channel = ChannelId;
+ Write.Header.Size = static_cast<int32_t>(Data.size());
+ Write.Data = std::move(Data);
+ Write.Handler = std::move(Handler);
+
+ m_SendQueue.push_back(std::move(Write));
+ if (m_SendQueue.size() == 1)
+ {
+ FlushNextSend();
+ }
+ });
}
void
-ComputeSocket::PingThreadProc()
+AsyncComputeSocket::AsyncSendDetach(int ChannelId, SendHandler Handler)
{
- while (true)
- {
+ auto Self = shared_from_this();
+ asio::dispatch(m_Strand, [this, Self, ChannelId, Handler = std::move(Handler)]() mutable {
+ if (m_Closed)
{
- std::unique_lock<std::mutex> Lock(m_PingMutex);
- if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; }))
+ if (Handler)
{
- break;
+ Handler(asio::error::make_error_code(asio::error::operation_aborted));
}
+ return;
}
- std::lock_guard<std::mutex> Lock(m_SendMutex);
- FrameHeader Header;
- Header.Channel = 0;
- Header.Size = ControlPing;
- m_Transport->SendMessage(&Header, sizeof(Header));
- }
+ PendingWrite Write;
+ Write.Header.Channel = ChannelId;
+ Write.Header.Size = ControlDetach;
+ Write.Handler = std::move(Handler);
+
+ m_SendQueue.push_back(std::move(Write));
+ if (m_SendQueue.size() == 1)
+ {
+ FlushNextSend();
+ }
+ });
}
void
-ComputeSocket::RecvThreadProc()
+AsyncComputeSocket::FlushNextSend()
{
- // Writers are cached locally to avoid taking m_WritersMutex on every frame.
- // The shared m_Writers map is only accessed when a channel is seen for the first time.
- std::unordered_map<int, ComputeBufferWriter> CachedWriters;
+ if (m_SendQueue.empty() || m_Closed)
+ {
+ return;
+ }
- FrameHeader Header;
- while (m_Transport->RecvMessage(&Header, sizeof(Header)))
+ PendingWrite& Front = m_SendQueue.front();
+
+ if (Front.Data.empty())
{
- if (Header.Size >= 0)
- {
- // Data frame
- auto It = CachedWriters.find(Header.Channel);
- if (It == CachedWriters.end())
- {
- std::lock_guard<std::mutex> Lock(m_WritersMutex);
- auto WIt = m_Writers.find(Header.Channel);
- if (WIt == m_Writers.end())
- {
- ZEN_WARN("recv frame for unknown channel {}", Header.Channel);
- // Skip the data
- std::vector<uint8_t> Discard(Header.Size);
- m_Transport->RecvMessage(Discard.data(), Header.Size);
- continue;
- }
- It = CachedWriters.emplace(Header.Channel, WIt->second).first;
- }
+ // Control frame - header only
+ auto Self = shared_from_this();
+ m_Transport->AsyncWrite(&Front.Header,
+ sizeof(FrameHeader),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ SendHandler Handler = std::move(m_SendQueue.front().Handler);
+ m_SendQueue.pop_front();
- ComputeBufferWriter& Writer = It->second;
- uint8_t* Dest = Writer.WaitToWrite(Header.Size);
- if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size))
- {
- ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size);
- return;
- }
- Writer.AdvanceWritePosition(Header.Size);
- }
- else if (Header.Size == ControlDetach)
- {
- // Detach the recv buffer for this channel
- CachedWriters.erase(Header.Channel);
+ if (Handler)
+ {
+ Handler(Ec);
+ }
- std::lock_guard<std::mutex> Lock(m_WritersMutex);
- auto It = m_Writers.find(Header.Channel);
- if (It != m_Writers.end())
- {
- It->second.MarkComplete();
- m_Writers.erase(It);
- }
- }
- else if (Header.Size == ControlPing)
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("send error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
+
+ FlushNextSend();
+ }));
+ }
+ else
+ {
+ // Data frame - write header first, then payload
+ auto Self = shared_from_this();
+ m_Transport->AsyncWrite(&Front.Header,
+ sizeof(FrameHeader),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ if (Ec)
+ {
+ SendHandler Handler = std::move(m_SendQueue.front().Handler);
+ m_SendQueue.pop_front();
+ if (Handler)
+ {
+ Handler(Ec);
+ }
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("send header error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
+
+ PendingWrite& Payload = m_SendQueue.front();
+ m_Transport->AsyncWrite(
+ Payload.Data.data(),
+ Payload.Data.size(),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ SendHandler Handler = std::move(m_SendQueue.front().Handler);
+ m_SendQueue.pop_front();
+
+ if (Handler)
+ {
+ Handler(Ec);
+ }
+
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("send payload error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
+
+ FlushNextSend();
+ }));
+ }));
+ }
+}
+
+void
+AsyncComputeSocket::StartPingTimer()
+{
+ if (m_Closed)
+ {
+ return;
+ }
+
+ m_PingTimer.expires_after(std::chrono::seconds(2));
+
+ auto Self = shared_from_this();
+ m_PingTimer.async_wait(asio::bind_executor(m_Strand, [this, Self](const asio::error_code& Ec) {
+ if (Ec || m_Closed)
{
- // Ping response - ignore
+ return;
}
- else
+
+ // Enqueue a ping control frame
+ PendingWrite Write;
+ Write.Header.Channel = 0;
+ Write.Header.Size = ControlPing;
+
+ m_SendQueue.push_back(std::move(Write));
+ if (m_SendQueue.size() == 1)
{
- ZEN_WARN("invalid frame header size: {}", Header.Size);
- return;
+ FlushNextSend();
}
- }
+
+ StartPingTimer();
+ }));
}
void
-ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader)
+AsyncComputeSocket::HandleError()
{
- // Each channel has its own send thread. All send threads share m_SendMutex
- // to serialize writes to the transport, since TCP requires atomic frame writes.
- FrameHeader Header;
- Header.Channel = Channel;
+ if (m_Closed)
+ {
+ return;
+ }
+
+ Close();
- const uint8_t* Data;
- while ((Data = Reader.WaitToRead(1)) != nullptr)
+ // Notify all channels that the connection is gone so agents can clean up
+ for (auto& [ChannelId, Handler] : m_DetachHandlers)
{
- std::lock_guard<std::mutex> Lock(m_SendMutex);
+ if (Handler)
+ {
+ Handler();
+ }
+ }
+}
- Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize());
- m_Transport->SendMessage(&Header, sizeof(Header));
- m_Transport->SendMessage(Data, Header.Size);
- Reader.AdvanceReadPosition(Header.Size);
+void
+AsyncComputeSocket::Close()
+{
+ if (m_Closed)
+ {
+ return;
}
- if (Reader.IsComplete())
+ m_Closed = true;
+ m_PingTimer.cancel();
+
+ if (m_Transport)
{
- std::lock_guard<std::mutex> Lock(m_SendMutex);
- Header.Size = ControlDetach;
- m_Transport->SendMessage(&Header, sizeof(Header));
+ m_Transport->Close();
}
}
diff --git a/src/zenhorde/hordecomputesocket.h b/src/zenhorde/hordecomputesocket.h
index 0c3cb4195..6c494603a 100644
--- a/src/zenhorde/hordecomputesocket.h
+++ b/src/zenhorde/hordecomputesocket.h
@@ -2,45 +2,74 @@
#pragma once
-#include "hordecomputebuffer.h"
-#include "hordecomputechannel.h"
#include "hordetransport.h"
#include <zencore/logbase.h>
-#include <condition_variable>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# undef SendMessage
+#endif
+
+#include <deque>
+#include <functional>
#include <memory>
-#include <mutex>
-#include <thread>
+#include <system_error>
#include <unordered_map>
#include <vector>
namespace zen::horde {
-/** Multiplexed socket that routes data between multiple channels over a single transport.
+class AsyncComputeTransport;
+
+/** Handler called when a data frame arrives for a channel. */
+using FrameHandler = std::function<void(std::vector<uint8_t> Data)>;
+
+/** Handler called when a channel is detached by the remote peer. */
+using DetachHandler = std::function<void()>;
+
+/** Handler for async send completion. */
+using SendHandler = std::function<void(const std::error_code&)>;
+
+/** Async multiplexed socket that routes data between channels over a single transport.
*
- * Each channel is identified by an integer ID and backed by a pair of ComputeBuffers.
- * A recv thread demultiplexes incoming frames to channel-specific buffers, while
- * per-channel send threads multiplex outgoing data onto the shared transport.
+ * Uses an async recv pump, a serialized send queue, and a periodic ping timer -
+ * all running on a shared io_context.
*
- * Wire format per frame: [channelId (4B)][size (4B)][data]
- * Control messages use negative sizes: -2 = detach (channel closed), -3 = ping.
+ * Wire format per frame: [channelId(4B)][size(4B)][data].
+ * Control messages use negative sizes: -2 = detach, -3 = ping.
*/
-class ComputeSocket
+class AsyncComputeSocket : public std::enable_shared_from_this<AsyncComputeSocket>
{
public:
- explicit ComputeSocket(std::unique_ptr<ComputeTransport> Transport);
- ~ComputeSocket();
+ AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext);
+ ~AsyncComputeSocket();
+
+ AsyncComputeSocket(const AsyncComputeSocket&) = delete;
+ AsyncComputeSocket& operator=(const AsyncComputeSocket&) = delete;
+
+ /** Register callbacks for a channel. Must be called before StartRecvPump(). */
+ void RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach);
- ComputeSocket(const ComputeSocket&) = delete;
- ComputeSocket& operator=(const ComputeSocket&) = delete;
+ /** Begin the async recv pump and ping timer. */
+ void StartRecvPump();
- /** Create a channel with the given ID.
- * Allocates anonymous in-process buffers and spawns a send thread for the channel. */
- Ref<ComputeChannel> CreateChannel(int ChannelId);
+ /** Enqueue a data frame for async transmission. */
+ void AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler = {});
- /** Start the recv pump and ping threads. Must be called after all channels are created. */
- void StartCommunication();
+ /** Send a control frame (detach) for a channel. */
+ void AsyncSendDetach(int ChannelId, SendHandler Handler = {});
+
+ /** Close the transport and cancel all pending operations. */
+ void Close();
+
+ /** The strand on which all socket I/O callbacks run. Channels that need to serialize
+ * their own state with OnFrame/OnDetach (which are invoked from this strand) should
+ * bind their timers and async operations to it as well. */
+ asio::strand<asio::any_io_executor>& GetStrand() { return m_Strand; }
private:
struct FrameHeader
@@ -49,31 +78,35 @@ private:
int32_t Size = 0;
};
+ struct PendingWrite
+ {
+ FrameHeader Header;
+ std::vector<uint8_t> Data;
+ SendHandler Handler;
+ };
+
static constexpr int32_t ControlDetach = -2;
static constexpr int32_t ControlPing = -3;
LoggerRef Log() { return m_Log; }
- void RecvThreadProc();
- void SendThreadProc(int Channel, ComputeBufferReader Reader);
- void PingThreadProc();
-
- LoggerRef m_Log;
- std::unique_ptr<ComputeTransport> m_Transport;
- std::mutex m_SendMutex; ///< Serializes writes to the transport
-
- std::mutex m_WritersMutex;
- std::unordered_map<int, ComputeBufferWriter> m_Writers; ///< Recv-side: writers keyed by channel ID
+ void DoRecvHeader();
+ void DoRecvPayload(FrameHeader Header);
+ void FlushNextSend();
+ void StartPingTimer();
+ void HandleError();
- std::vector<ComputeBufferReader> m_Readers; ///< Send-side: readers for join on destruction
- std::unordered_map<int, std::thread> m_SendThreads; ///< One send thread per channel
+ LoggerRef m_Log;
+ std::unique_ptr<AsyncComputeTransport> m_Transport;
+ asio::strand<asio::any_io_executor> m_Strand;
+ asio::steady_timer m_PingTimer;
- std::thread m_RecvThread;
- std::thread m_PingThread;
+ std::unordered_map<int, FrameHandler> m_FrameHandlers;
+ std::unordered_map<int, DetachHandler> m_DetachHandlers;
- bool m_PingShouldStop = false;
- std::mutex m_PingMutex;
- std::condition_variable m_PingCV;
+ FrameHeader m_RecvHeader;
+ std::deque<PendingWrite> m_SendQueue;
+ bool m_Closed = false;
};
} // namespace zen::horde
diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp
index 2dca228d9..9f6125c64 100644
--- a/src/zenhorde/hordeconfig.cpp
+++ b/src/zenhorde/hordeconfig.cpp
@@ -1,5 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#include <zencore/logging.h>
+#include <zencore/string.h>
#include <zenhorde/hordeconfig.h>
namespace zen::horde {
@@ -9,12 +11,14 @@ HordeConfig::Validate() const
{
if (ServerUrl.empty())
{
+ ZEN_WARN("Horde server URL is not configured");
return false;
}
// Relay mode implies AES encryption
if (Mode == ConnectionMode::Relay && EncryptionMode != Encryption::AES)
{
+ ZEN_WARN("Horde relay mode requires AES encryption, but encryption is set to '{}'", ToString(EncryptionMode));
return false;
}
@@ -52,37 +56,39 @@ ToString(Encryption Enc)
bool
FromString(ConnectionMode& OutMode, std::string_view Str)
{
- if (Str == "direct")
+ if (StrCaseCompare(Str, "direct") == 0)
{
OutMode = ConnectionMode::Direct;
return true;
}
- if (Str == "tunnel")
+ if (StrCaseCompare(Str, "tunnel") == 0)
{
OutMode = ConnectionMode::Tunnel;
return true;
}
- if (Str == "relay")
+ if (StrCaseCompare(Str, "relay") == 0)
{
OutMode = ConnectionMode::Relay;
return true;
}
+ ZEN_WARN("unrecognized Horde connection mode: '{}'", Str);
return false;
}
bool
FromString(Encryption& OutEnc, std::string_view Str)
{
- if (Str == "none")
+ if (StrCaseCompare(Str, "none") == 0)
{
OutEnc = Encryption::None;
return true;
}
- if (Str == "aes")
+ if (StrCaseCompare(Str, "aes") == 0)
{
OutEnc = Encryption::AES;
return true;
}
+ ZEN_WARN("unrecognized Horde encryption mode: '{}'", Str);
return false;
}
diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp
index f88c95da2..ea0ea1e83 100644
--- a/src/zenhorde/hordeprovisioner.cpp
+++ b/src/zenhorde/hordeprovisioner.cpp
@@ -6,49 +6,83 @@
#include "hordeagent.h"
#include "hordebundle.h"
+#include <zencore/compactbinary.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/scopeguard.h>
#include <zencore/thread.h>
#include <zencore/trace.h>
+#include <zenhttp/httpclient.h>
+#include <zenutil/workerpools.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <algorithm>
#include <chrono>
#include <thread>
namespace zen::horde {
-struct HordeProvisioner::AgentWrapper
-{
- std::thread Thread;
- std::atomic<bool> ShouldExit{false};
-};
-
HordeProvisioner::HordeProvisioner(const HordeConfig& Config,
const std::filesystem::path& BinariesPath,
const std::filesystem::path& WorkingDir,
- std::string_view OrchestratorEndpoint)
+ std::string_view OrchestratorEndpoint,
+ std::string_view CoordinatorSession,
+ bool CleanStart,
+ std::string_view TraceHost)
: m_Config(Config)
, m_BinariesPath(BinariesPath)
, m_WorkingDir(WorkingDir)
, m_OrchestratorEndpoint(OrchestratorEndpoint)
+, m_CoordinatorSession(CoordinatorSession)
+, m_CleanStart(CleanStart)
+, m_TraceHost(TraceHost)
, m_Log(zen::logging::Get("horde.provisioner"))
{
+ m_IoContext = std::make_unique<asio::io_context>();
+
+ auto Work = asio::make_work_guard(*m_IoContext);
+ for (int i = 0; i < IoThreadCount; ++i)
+ {
+ m_IoThreads.emplace_back([this, i, Work] {
+ zen::SetCurrentThreadName(fmt::format("horde_io_{}", i));
+ m_IoContext->run();
+ });
+ }
}
HordeProvisioner::~HordeProvisioner()
{
- std::lock_guard<std::mutex> Lock(m_AgentsLock);
- for (auto& Agent : m_Agents)
+ m_AskForAgents.store(false);
+ m_ShutdownEvent.Set();
+
+ // Shut down async agents and io_context
{
- Agent->ShouldExit.store(true);
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ for (auto& Entry : m_AsyncAgents)
+ {
+ Entry.Agent->Cancel();
+ }
+ m_AsyncAgents.clear();
}
- for (auto& Agent : m_Agents)
+
+ m_IoContext->stop();
+
+ for (auto& Thread : m_IoThreads)
{
- if (Agent->Thread.joinable())
+ if (Thread.joinable())
{
- Agent->Thread.join();
+ Thread.join();
}
}
+
+ // Wait for all pool work items to finish before destroying members they reference
+ if (m_PendingWorkItems.load() > 0)
+ {
+ m_AllWorkDone.Wait();
+ }
}
void
@@ -56,9 +90,23 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count)
{
ZEN_TRACE_CPU("HordeProvisioner::SetTargetCoreCount");
- m_TargetCoreCount.store(std::min(Count, static_cast<uint32_t>(m_Config.MaxCores)));
+ const uint32_t ClampedCount = std::min(Count, static_cast<uint32_t>(m_Config.MaxCores));
+ const uint32_t PreviousTarget = m_TargetCoreCount.exchange(ClampedCount);
+
+ if (ClampedCount != PreviousTarget)
+ {
+ ZEN_INFO("target core count changed: {} -> {} (active={}, estimated={})",
+ PreviousTarget,
+ ClampedCount,
+ m_ActiveCoreCount.load(),
+ m_EstimatedCoreCount.load());
+ }
- while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load())
+ // Only provision if the gap is at least one agent-sized chunk. Without
+ // this, draining a 32-core agent to cover a 28-core excess would leave a
+ // 4-core gap that triggers a 32-core provision, which triggers another
+ // drain, ad infinitum.
+ while (m_EstimatedCoreCount.load() + EstimatedCoresPerAgent <= m_TargetCoreCount.load())
{
if (!m_AskForAgents.load())
{
@@ -67,21 +115,108 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count)
RequestAgent();
}
- // Clean up finished agent threads
- std::lock_guard<std::mutex> Lock(m_AgentsLock);
- for (auto It = m_Agents.begin(); It != m_Agents.end();)
+ // Scale down async agents
{
- if ((*It)->ShouldExit.load())
+ std::lock_guard<std::mutex> AsyncLock(m_AsyncAgentsLock);
+
+ uint32_t AsyncActive = m_ActiveCoreCount.load();
+ uint32_t AsyncTarget = m_TargetCoreCount.load();
+
+ uint32_t AlreadyDrainingCores = 0;
+ for (const auto& Entry : m_AsyncAgents)
{
- if ((*It)->Thread.joinable())
+ if (Entry.Draining)
{
- (*It)->Thread.join();
+ AlreadyDrainingCores += Entry.CoreCount;
}
- It = m_Agents.erase(It);
}
- else
+
+ uint32_t EffectiveAsync = (AsyncActive > AlreadyDrainingCores) ? AsyncActive - AlreadyDrainingCores : 0;
+
+ if (EffectiveAsync > AsyncTarget)
{
- ++It;
+ struct Candidate
+ {
+ AsyncAgentEntry* Entry;
+ int Workload;
+ };
+ std::vector<Candidate> Candidates;
+
+ for (auto& Entry : m_AsyncAgents)
+ {
+ if (Entry.Draining || Entry.RemoteEndpoint.empty())
+ {
+ continue;
+ }
+
+ int Workload = 0;
+ bool Reachable = false;
+ HttpClientSettings Settings;
+ Settings.LogCategory = "horde.drain";
+ Settings.ConnectTimeout = std::chrono::milliseconds{2000};
+ Settings.Timeout = std::chrono::milliseconds{3000};
+ try
+ {
+ HttpClient Client(Entry.RemoteEndpoint, Settings);
+ HttpClient::Response Resp = Client.Get("/compute/session/status");
+ if (Resp.IsSuccess())
+ {
+ CbObject Status = Resp.AsObject();
+ Workload = Status["actions_pending"].AsInt32(0) + Status["actions_running"].AsInt32(0);
+ Reachable = true;
+ }
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_DEBUG("agent lease={} not yet reachable for drain: {}", Entry.LeaseId, Ex.what());
+ }
+
+ if (Reachable)
+ {
+ Candidates.push_back({&Entry, Workload});
+ }
+ }
+
+ const uint32_t ExcessCores = EffectiveAsync - AsyncTarget;
+ uint32_t CoresDrained = 0;
+
+ while (CoresDrained < ExcessCores && !Candidates.empty())
+ {
+ const uint32_t Remaining = ExcessCores - CoresDrained;
+
+ Candidates.erase(std::remove_if(Candidates.begin(),
+ Candidates.end(),
+ [Remaining](const Candidate& C) { return C.Entry->CoreCount > Remaining; }),
+ Candidates.end());
+
+ if (Candidates.empty())
+ {
+ break;
+ }
+
+ Candidate* Best = &Candidates[0];
+ for (auto& C : Candidates)
+ {
+ if (C.Entry->CoreCount > Best->Entry->CoreCount ||
+ (C.Entry->CoreCount == Best->Entry->CoreCount && C.Workload < Best->Workload))
+ {
+ Best = &C;
+ }
+ }
+
+ ZEN_INFO("draining async agent lease={} ({} cores, workload={})",
+ Best->Entry->LeaseId,
+ Best->Entry->CoreCount,
+ Best->Workload);
+
+ DrainAsyncAgent(*Best->Entry);
+ CoresDrained += Best->Entry->CoreCount;
+
+ AsyncAgentEntry* Drained = Best->Entry;
+ Candidates.erase(
+ std::remove_if(Candidates.begin(), Candidates.end(), [Drained](const Candidate& C) { return C.Entry == Drained; }),
+ Candidates.end());
+ }
}
}
}
@@ -101,266 +236,395 @@ HordeProvisioner::GetStats() const
uint32_t
HordeProvisioner::GetAgentCount() const
{
- std::lock_guard<std::mutex> Lock(m_AgentsLock);
- return static_cast<uint32_t>(m_Agents.size());
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ return static_cast<uint32_t>(m_AsyncAgents.size());
}
-void
-HordeProvisioner::RequestAgent()
+compute::AgentProvisioningStatus
+HordeProvisioner::GetAgentStatus(std::string_view WorkerId) const
{
- m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent);
+ // Worker IDs are "horde-{LeaseId}" - strip the prefix to match lease ID
+ constexpr std::string_view Prefix = "horde-";
+ if (!WorkerId.starts_with(Prefix))
+ {
+ return compute::AgentProvisioningStatus::Unknown;
+ }
+ std::string_view LeaseId = WorkerId.substr(Prefix.size());
- std::lock_guard<std::mutex> Lock(m_AgentsLock);
+ std::lock_guard<std::mutex> AsyncLock(m_AsyncAgentsLock);
+ for (const auto& Entry : m_AsyncAgents)
+ {
+ if (Entry.LeaseId == LeaseId)
+ {
+ if (Entry.Draining)
+ {
+ return compute::AgentProvisioningStatus::Draining;
+ }
+ return compute::AgentProvisioningStatus::Active;
+ }
+ }
- auto Wrapper = std::make_unique<AgentWrapper>();
- AgentWrapper& Ref = *Wrapper;
- Wrapper->Thread = std::thread([this, &Ref] { ThreadAgent(Ref); });
+ // Check recently-drained agents that have already been cleaned up
+ std::string WorkerIdStr(WorkerId);
+ if (m_RecentlyDrainedWorkerIds.erase(WorkerIdStr) > 0)
+ {
+ // Also remove from the ordering queue so size accounting stays consistent.
+ auto It = std::find(m_RecentlyDrainedOrder.begin(), m_RecentlyDrainedOrder.end(), WorkerIdStr);
+ if (It != m_RecentlyDrainedOrder.end())
+ {
+ m_RecentlyDrainedOrder.erase(It);
+ }
+ return compute::AgentProvisioningStatus::Draining;
+ }
- m_Agents.push_back(std::move(Wrapper));
+ return compute::AgentProvisioningStatus::Unknown;
}
-void
-HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper)
+std::vector<std::string>
+HordeProvisioner::BuildAgentArgs(const MachineInfo& Machine) const
{
- ZEN_TRACE_CPU("HordeProvisioner::ThreadAgent");
+ std::vector<std::string> Args;
+ Args.emplace_back("compute");
+ Args.emplace_back("--http=asio");
+ Args.push_back(fmt::format("--port={}", m_Config.ZenServicePort));
+ Args.emplace_back("--data-dir=%UE_HORDE_SHARED_DIR%\\zen");
- static std::atomic<uint32_t> ThreadIndex{0};
- const uint32_t CurrentIndex = ThreadIndex.fetch_add(1);
+ if (m_CleanStart)
+ {
+ Args.emplace_back("--clean");
+ }
- zen::SetCurrentThreadName(fmt::format("horde_agent_{}", CurrentIndex));
+ if (!m_OrchestratorEndpoint.empty())
+ {
+ ExtendableStringBuilder<256> CoordArg;
+ CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint;
+ Args.emplace_back(CoordArg.ToView());
+ }
- std::unique_ptr<HordeAgent> Agent;
- uint32_t MachineCoreCount = 0;
+ {
+ ExtendableStringBuilder<128> IdArg;
+ IdArg << "--instance-id=horde-" << Machine.LeaseId;
+ Args.emplace_back(IdArg.ToView());
+ }
- auto _ = MakeGuard([&] {
- if (Agent)
- {
- Agent->CloseConnection();
- }
- Wrapper.ShouldExit.store(true);
- });
+ if (!m_CoordinatorSession.empty())
+ {
+ ExtendableStringBuilder<128> SessionArg;
+ SessionArg << "--coordinator-session=" << m_CoordinatorSession;
+ Args.emplace_back(SessionArg.ToView());
+ }
+ if (!m_TraceHost.empty())
{
- // EstimatedCoreCount is incremented speculatively when the agent is requested
- // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision.
- auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); });
+ ExtendableStringBuilder<128> TraceArg;
+ TraceArg << "--tracehost=" << m_TraceHost;
+ Args.emplace_back(TraceArg.ToView());
+ }
+ // In relay mode, the remote zenserver's local address is not reachable from the
+ // orchestrator. Pass the relay-visible endpoint so it announces the correct URL.
+ if (Machine.Mode == ConnectionMode::Relay)
+ {
+ const auto [Addr, Port] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort);
+ if (Addr.find(':') != std::string::npos)
+ {
+ Args.push_back(fmt::format("--announce-url=http://[{}]:{}", Addr, Port));
+ }
+ else
{
- ZEN_TRACE_CPU("HordeProvisioner::CreateBundles");
+ Args.push_back(fmt::format("--announce-url=http://{}:{}", Addr, Port));
+ }
+ }
- std::lock_guard<std::mutex> BundleLock(m_BundleLock);
+ return Args;
+}
- if (!m_BundlesCreated)
- {
- const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles";
+bool
+HordeProvisioner::InitializeHordeClient()
+{
+ ZEN_TRACE_CPU("HordeProvisioner::InitializeHordeClient");
+
+ std::lock_guard<std::mutex> BundleLock(m_BundleLock);
- std::vector<BundleFile> Files;
+ if (!m_BundlesCreated)
+ {
+ const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles";
+
+ std::vector<BundleFile> Files;
#if ZEN_PLATFORM_WINDOWS
- Files.emplace_back(m_BinariesPath / "zenserver.exe", false);
+ Files.emplace_back(m_BinariesPath / "zenserver.exe", false);
+ Files.emplace_back(m_BinariesPath / "zenserver.pdb", true);
#elif ZEN_PLATFORM_LINUX
- Files.emplace_back(m_BinariesPath / "zenserver", false);
- Files.emplace_back(m_BinariesPath / "zenserver.debug", true);
+ Files.emplace_back(m_BinariesPath / "zenserver", false);
+ Files.emplace_back(m_BinariesPath / "zenserver.debug", true);
#elif ZEN_PLATFORM_MAC
- Files.emplace_back(m_BinariesPath / "zenserver", false);
+ Files.emplace_back(m_BinariesPath / "zenserver", false);
#endif
- BundleResult Result;
- if (!BundleCreator::CreateBundle(Files, OutputDir, Result))
- {
- ZEN_WARN("failed to create bundle, cannot provision any agents!");
- m_AskForAgents.store(false);
- return;
- }
-
- m_Bundles.emplace_back(Result.Locator, Result.BundleDir);
- m_BundlesCreated = true;
- }
-
- if (!m_HordeClient)
- {
- m_HordeClient = std::make_unique<HordeClient>(m_Config);
- if (!m_HordeClient->Initialize())
- {
- ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!");
- m_AskForAgents.store(false);
- return;
- }
- }
+ BundleResult Result;
+ if (!BundleCreator::CreateBundle(Files, OutputDir, Result))
+ {
+ ZEN_WARN("failed to create bundle, cannot provision any agents!");
+ m_AskForAgents.store(false);
+ m_ShutdownEvent.Set();
+ return false;
}
- if (!m_AskForAgents.load())
+ m_Bundles.emplace_back(Result.Locator, Result.BundleDir);
+ m_BundlesCreated = true;
+ }
+
+ if (!m_HordeClient)
+ {
+ m_HordeClient = std::make_unique<HordeClient>(m_Config);
+ if (!m_HordeClient->Initialize())
{
- return;
+ ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!");
+ m_AskForAgents.store(false);
+ m_ShutdownEvent.Set();
+ return false;
}
+ }
- m_AgentsRequesting.fetch_add(1);
- auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); });
+ return true;
+}
- // Simple backoff: if the last machine request failed, wait up to 5 seconds
- // before trying again.
- //
- // Note however that it's possible that multiple threads enter this code at
- // the same time if multiple agents are requested at once, and they will all
- // see the same last failure time and back off accordingly. We might want to
- // use a semaphore or similar to limit the number of concurrent requests.
+void
+HordeProvisioner::RequestAgent()
+{
+ m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent);
- if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0)
- {
- auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
- const uint64_t ElapsedNs = Now - LastFail;
- const uint64_t ElapsedMs = ElapsedNs / 1'000'000;
- if (ElapsedMs < 5000)
- {
- const uint64_t WaitMs = 5000 - ElapsedMs;
- for (uint64_t Waited = 0; Waited < WaitMs && !Wrapper.ShouldExit.load(); Waited += 100)
- {
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
- }
+ if (m_PendingWorkItems.fetch_add(1) == 0)
+ {
+ m_AllWorkDone.Reset();
+ }
- if (Wrapper.ShouldExit.load())
+ GetSmallWorkerPool(EWorkloadType::Background)
+ .ScheduleWork(
+ [this] {
+ ProvisionAgent();
+ if (m_PendingWorkItems.fetch_sub(1) == 1)
{
- return;
+ m_AllWorkDone.Set();
}
- }
- }
+ },
+ WorkerThreadPool::EMode::EnableBacklog);
+}
- if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load())
- {
- return;
- }
+void
+HordeProvisioner::ProvisionAgent()
+{
+ ZEN_TRACE_CPU("HordeProvisioner::ProvisionAgent");
+
+ // EstimatedCoreCount is incremented speculatively when the agent is requested
+ // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision.
+ auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); });
- std::string RequestBody = m_HordeClient->BuildRequestBody();
+ if (!InitializeHordeClient())
+ {
+ return;
+ }
- // Resolve cluster if needed
- std::string ClusterId = m_Config.Cluster;
- if (ClusterId == HordeConfig::ClusterAuto)
+ if (!m_AskForAgents.load())
+ {
+ return;
+ }
+
+ m_AgentsRequesting.fetch_add(1);
+ auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); });
+
+ // Simple backoff: if the last machine request failed, wait up to 5 seconds
+ // before trying again.
+ //
+ // Note however that it's possible that multiple threads enter this code at
+ // the same time if multiple agents are requested at once, and they will all
+ // see the same last failure time and back off accordingly. We might want to
+ // use a semaphore or similar to limit the number of concurrent requests.
+
+ if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0)
+ {
+ auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
+ const uint64_t ElapsedNs = Now - LastFail;
+ const uint64_t ElapsedMs = ElapsedNs / 1'000'000;
+ if (ElapsedMs < 5000)
{
- ClusterInfo Cluster;
- if (!m_HordeClient->ResolveCluster(RequestBody, Cluster))
+ // Wait on m_ShutdownEvent so shutdown wakes this pool thread immediately instead
+ // of stalling for up to 5s in 100ms sleep chunks. Wait() returns true iff the
+ // event was signaled (shutdown); false means the backoff elapsed normally.
+ const uint64_t WaitMs = 5000 - ElapsedMs;
+ if (m_ShutdownEvent.Wait(static_cast<int>(WaitMs)))
{
- ZEN_WARN("failed to resolve cluster");
- m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
return;
}
- ClusterId = Cluster.ClusterId;
}
+ }
- MachineInfo Machine;
- if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid())
+ if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load())
+ {
+ return;
+ }
+
+ std::string RequestBody = m_HordeClient->BuildRequestBody();
+
+ // Resolve cluster if needed
+ std::string ClusterId = m_Config.Cluster;
+ if (ClusterId == HordeConfig::ClusterAuto)
+ {
+ ClusterInfo Cluster;
+ if (!m_HordeClient->ResolveCluster(RequestBody, Cluster))
{
+ ZEN_WARN("failed to resolve cluster");
m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
return;
}
+ ClusterId = Cluster.ClusterId;
+ }
- m_LastRequestFailTime.store(0);
+ ZEN_INFO("requesting machine from Horde (cluster='{}', cores={}/{})",
+ ClusterId.empty() ? "default" : ClusterId.c_str(),
+ m_ActiveCoreCount.load(),
+ m_TargetCoreCount.load());
- if (Wrapper.ShouldExit.load())
- {
- return;
- }
+ MachineInfo Machine;
+ if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid())
+ {
+ m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
+ return;
+ }
- // Connect to agent and perform handshake
- Agent = std::make_unique<HordeAgent>(Machine);
- if (!Agent->IsValid())
- {
- ZEN_WARN("agent creation failed for {}:{}", Machine.GetConnectionAddress(), Machine.GetConnectionPort());
- return;
- }
+ m_LastRequestFailTime.store(0);
- if (!Agent->BeginCommunication())
- {
- ZEN_WARN("BeginCommunication failed");
- return;
- }
+ if (!m_AskForAgents.load())
+ {
+ return;
+ }
- for (auto& [Locator, BundleDir] : m_Bundles)
+ AsyncAgentConfig AgentConfig;
+ AgentConfig.Machine = Machine;
+ AgentConfig.Bundles = m_Bundles;
+ AgentConfig.Args = BuildAgentArgs(Machine);
+
+#if ZEN_PLATFORM_WINDOWS
+ AgentConfig.UseWine = !Machine.IsWindows;
+ AgentConfig.Executable = "zenserver.exe";
+#else
+ AgentConfig.UseWine = false;
+ AgentConfig.Executable = "zenserver";
+#endif
+
+ auto AsyncAgent = std::make_shared<AsyncHordeAgent>(*m_IoContext);
+
+ AsyncAgentEntry Entry;
+ Entry.Agent = AsyncAgent;
+ Entry.LeaseId = Machine.LeaseId;
+ Entry.CoreCount = Machine.LogicalCores;
+
+ const auto [EndpointAddr, EndpointPort] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort);
+ if (EndpointAddr.find(':') != std::string::npos)
+ {
+ Entry.RemoteEndpoint = fmt::format("http://[{}]:{}", EndpointAddr, EndpointPort);
+ }
+ else
+ {
+ Entry.RemoteEndpoint = fmt::format("http://{}:{}", EndpointAddr, EndpointPort);
+ }
+
+ {
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ m_AsyncAgents.push_back(std::move(Entry));
+ }
+
+ AsyncAgent->Start(std::move(AgentConfig), [this, AsyncAgent](const AsyncAgentResult& Result) {
+ if (Result.CoreCount > 0)
{
- if (Wrapper.ShouldExit.load())
+ // Only subtract estimated cores if not already subtracted by DrainAsyncAgent
+ bool WasDraining = false;
{
- return;
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ for (const auto& Entry : m_AsyncAgents)
+ {
+ if (Entry.Agent == AsyncAgent)
+ {
+ WasDraining = Entry.Draining;
+ break;
+ }
+ }
}
- if (!Agent->UploadBinaries(BundleDir, Locator))
+ if (!WasDraining)
{
- ZEN_WARN("UploadBinaries failed");
- return;
+ m_EstimatedCoreCount.fetch_sub(Result.CoreCount);
}
+ m_ActiveCoreCount.fetch_sub(Result.CoreCount);
+ m_AgentsActive.fetch_sub(1);
}
+ OnAsyncAgentDone(AsyncAgent);
+ });
- if (Wrapper.ShouldExit.load())
- {
- return;
- }
-
- // Build command line for remote zenserver
- std::vector<std::string> ArgStrings;
- ArgStrings.push_back("compute");
- ArgStrings.push_back("--http=asio");
+ // Track active cores (estimated was already added by RequestAgent)
+ m_EstimatedCoreCount.fetch_add(Machine.LogicalCores);
+ m_ActiveCoreCount.fetch_add(Machine.LogicalCores);
+ m_AgentsActive.fetch_add(1);
+}
- // TEMP HACK - these should be made fully dynamic
- // these are currently here to allow spawning the compute agent locally
- // for debugging purposes (i.e with a local Horde Server+Agent setup)
- ArgStrings.push_back(fmt::format("--port={}", m_Config.ZenServicePort));
- ArgStrings.push_back("--data-dir=c:\\temp\\123");
+void
+HordeProvisioner::DrainAsyncAgent(AsyncAgentEntry& Entry)
+{
+ Entry.Draining = true;
+ m_EstimatedCoreCount.fetch_sub(Entry.CoreCount);
+ m_AgentsDraining.fetch_add(1);
- if (!m_OrchestratorEndpoint.empty())
- {
- ExtendableStringBuilder<256> CoordArg;
- CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint;
- ArgStrings.emplace_back(CoordArg.ToView());
- }
+ HttpClientSettings Settings;
+ Settings.LogCategory = "horde.drain";
+ Settings.ConnectTimeout = std::chrono::milliseconds{5000};
+ Settings.Timeout = std::chrono::milliseconds{10000};
- {
- ExtendableStringBuilder<128> IdArg;
- IdArg << "--instance-id=horde-" << Machine.LeaseId;
- ArgStrings.emplace_back(IdArg.ToView());
- }
+ try
+ {
+ HttpClient Client(Entry.RemoteEndpoint, Settings);
- std::vector<const char*> Args;
- Args.reserve(ArgStrings.size());
- for (const std::string& Arg : ArgStrings)
+ HttpClient::Response Response = Client.Post("/compute/session/drain");
+ if (!Response.IsSuccess())
{
- Args.push_back(Arg.c_str());
+ ZEN_WARN("drain[{}]: POST session/drain failed: HTTP {}", Entry.LeaseId, static_cast<int>(Response.StatusCode));
+ return;
}
-#if ZEN_PLATFORM_WINDOWS
- const bool UseWine = !Machine.IsWindows;
- const char* AppName = "zenserver.exe";
-#else
- const bool UseWine = false;
- const char* AppName = "zenserver";
-#endif
-
- Agent->Execute(AppName, Args.data(), Args.size(), nullptr, nullptr, 0, UseWine);
-
- ZEN_INFO("remote execution started on [{}:{}] lease={}",
- Machine.GetConnectionAddress(),
- Machine.GetConnectionPort(),
- Machine.LeaseId);
-
- MachineCoreCount = Machine.LogicalCores;
- m_EstimatedCoreCount.fetch_add(MachineCoreCount);
- m_ActiveCoreCount.fetch_add(MachineCoreCount);
- m_AgentsActive.fetch_add(1);
+ ZEN_INFO("drain[{}]: session/drain accepted, sending sunset", Entry.LeaseId);
+ (void)Client.Post("/compute/session/sunset");
}
+ catch (const std::exception& Ex)
+ {
+ ZEN_WARN("drain[{}]: exception: {}", Entry.LeaseId, Ex.what());
+ }
+}
- // Agent poll loop
-
- auto ActiveGuard = MakeGuard([&]() {
- m_EstimatedCoreCount.fetch_sub(MachineCoreCount);
- m_ActiveCoreCount.fetch_sub(MachineCoreCount);
- m_AgentsActive.fetch_sub(1);
- });
-
- while (Agent->IsValid() && !Wrapper.ShouldExit.load())
+void
+HordeProvisioner::OnAsyncAgentDone(std::shared_ptr<AsyncHordeAgent> Agent)
+{
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ for (auto It = m_AsyncAgents.begin(); It != m_AsyncAgents.end(); ++It)
{
- const bool LogOutput = false;
- if (!Agent->Poll(LogOutput))
+ if (It->Agent == Agent)
{
+ if (It->Draining)
+ {
+ m_AgentsDraining.fetch_sub(1);
+ std::string WorkerId = "horde-" + It->LeaseId;
+ if (m_RecentlyDrainedWorkerIds.insert(WorkerId).second)
+ {
+ m_RecentlyDrainedOrder.push_back(WorkerId);
+ while (m_RecentlyDrainedOrder.size() > RecentlyDrainedCapacity)
+ {
+ m_RecentlyDrainedWorkerIds.erase(m_RecentlyDrainedOrder.front());
+ m_RecentlyDrainedOrder.pop_front();
+ }
+ }
+ }
+ m_AsyncAgents.erase(It);
break;
}
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
}
diff --git a/src/zenhorde/hordetransport.cpp b/src/zenhorde/hordetransport.cpp
index 69766e73e..65eaea477 100644
--- a/src/zenhorde/hordetransport.cpp
+++ b/src/zenhorde/hordetransport.cpp
@@ -9,71 +9,33 @@ ZEN_THIRD_PARTY_INCLUDES_START
#include <asio.hpp>
ZEN_THIRD_PARTY_INCLUDES_END
-#if ZEN_PLATFORM_WINDOWS
-# undef SendMessage
-#endif
-
namespace zen::horde {
-// ComputeTransport base
+// --- AsyncTcpComputeTransport ---
-bool
-ComputeTransport::SendMessage(const void* Data, size_t Size)
+struct AsyncTcpComputeTransport::Impl
{
- const uint8_t* Ptr = static_cast<const uint8_t*>(Data);
- size_t Remaining = Size;
-
- while (Remaining > 0)
- {
- const size_t Sent = Send(Ptr, Remaining);
- if (Sent == 0)
- {
- return false;
- }
- Ptr += Sent;
- Remaining -= Sent;
- }
+ asio::io_context& IoContext;
+ asio::ip::tcp::socket Socket;
- return true;
-}
+ explicit Impl(asio::io_context& Ctx) : IoContext(Ctx), Socket(Ctx) {}
+};
-bool
-ComputeTransport::RecvMessage(void* Data, size_t Size)
+AsyncTcpComputeTransport::AsyncTcpComputeTransport(asio::io_context& IoContext)
+: m_Impl(std::make_unique<Impl>(IoContext))
+, m_Log(zen::logging::Get("horde.transport.async"))
{
- uint8_t* Ptr = static_cast<uint8_t*>(Data);
- size_t Remaining = Size;
-
- while (Remaining > 0)
- {
- const size_t Received = Recv(Ptr, Remaining);
- if (Received == 0)
- {
- return false;
- }
- Ptr += Received;
- Remaining -= Received;
- }
-
- return true;
}
-// TcpComputeTransport - ASIO pimpl
-
-struct TcpComputeTransport::Impl
+AsyncTcpComputeTransport::~AsyncTcpComputeTransport()
{
- asio::io_context IoContext;
- asio::ip::tcp::socket Socket;
-
- Impl() : Socket(IoContext) {}
-};
+ Close();
+}
-// Uses ASIO in synchronous mode only — no async operations or io_context::run().
-// The io_context is only needed because ASIO sockets require one to be constructed.
-TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info)
-: m_Impl(std::make_unique<Impl>())
-, m_Log(zen::logging::Get("horde.transport"))
+void
+AsyncTcpComputeTransport::AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler)
{
- ZEN_TRACE_CPU("TcpComputeTransport::Connect");
+ ZEN_TRACE_CPU("AsyncTcpComputeTransport::AsyncConnect");
asio::error_code Ec;
@@ -82,80 +44,75 @@ TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info)
{
ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message());
m_HasErrors = true;
+ asio::post(m_Impl->IoContext, [Handler = std::move(Handler), Ec] { Handler(Ec); });
return;
}
const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort());
- m_Impl->Socket.connect(Endpoint, Ec);
- if (Ec)
- {
- ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message());
- m_HasErrors = true;
- return;
- }
+ // Copy the nonce so it survives past this scope into the async callback
+ auto NonceBuf = std::make_shared<std::vector<uint8_t>>(Info.Nonce, Info.Nonce + NonceSize);
- // Disable Nagle's algorithm for lower latency
- m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec);
-}
+ m_Impl->Socket.async_connect(Endpoint, [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec) mutable {
+ if (Ec)
+ {
+ ZEN_WARN("async connect failed: {}", Ec.message());
+ m_HasErrors = true;
+ Handler(Ec);
+ return;
+ }
-TcpComputeTransport::~TcpComputeTransport()
-{
- Close();
+ asio::error_code SetOptEc;
+ m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), SetOptEc);
+
+ // Send the 64-byte nonce as the first thing on the wire
+ asio::async_write(m_Impl->Socket,
+ asio::buffer(*NonceBuf),
+ [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec, size_t /*BytesWritten*/) {
+ if (Ec)
+ {
+ ZEN_WARN("nonce write failed: {}", Ec.message());
+ m_HasErrors = true;
+ }
+ Handler(Ec);
+ });
+ });
}
bool
-TcpComputeTransport::IsValid() const
+AsyncTcpComputeTransport::IsValid() const
{
return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed;
}
-size_t
-TcpComputeTransport::Send(const void* Data, size_t Size)
+void
+AsyncTcpComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler)
{
if (!IsValid())
{
- return 0;
- }
-
- asio::error_code Ec;
- const size_t Sent = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec);
-
- if (Ec)
- {
- m_HasErrors = true;
- return 0;
+ asio::post(m_Impl->IoContext,
+ [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- return Sent;
+ asio::async_write(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler));
}
-size_t
-TcpComputeTransport::Recv(void* Data, size_t Size)
+void
+AsyncTcpComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler)
{
if (!IsValid())
{
- return 0;
- }
-
- asio::error_code Ec;
- const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec);
-
- if (Ec)
- {
- return 0;
+ asio::post(m_Impl->IoContext,
+ [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- return Received;
-}
-
-void
-TcpComputeTransport::MarkComplete()
-{
+ asio::async_read(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler));
}
void
-TcpComputeTransport::Close()
+AsyncTcpComputeTransport::Close()
{
if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open())
{
diff --git a/src/zenhorde/hordetransport.h b/src/zenhorde/hordetransport.h
index 1b178dc0f..b5e841d7a 100644
--- a/src/zenhorde/hordetransport.h
+++ b/src/zenhorde/hordetransport.h
@@ -8,55 +8,60 @@
#include <cstddef>
#include <cstdint>
+#include <functional>
#include <memory>
+#include <system_error>
-#if ZEN_PLATFORM_WINDOWS
-# undef SendMessage
-#endif
+namespace asio {
+class io_context;
+}
namespace zen::horde {
-/** Abstract base interface for compute transports.
+/** Handler types for async transport operations. */
+using AsyncConnectHandler = std::function<void(const std::error_code&)>;
+using AsyncIoHandler = std::function<void(const std::error_code&, size_t)>;
+
+/** Abstract base for asynchronous compute transports.
*
- * Matches the UE FComputeTransport pattern. Concrete implementations handle
- * the underlying I/O (TCP, AES-wrapped, etc.) while this interface provides
- * blocking message helpers on top.
+ * All callbacks are invoked on the io_context that was provided at construction.
+ * Callers are responsible for strand serialization if needed.
*/
-class ComputeTransport
+class AsyncComputeTransport
{
public:
- virtual ~ComputeTransport() = default;
+ virtual ~AsyncComputeTransport() = default;
+
+ virtual bool IsValid() const = 0;
- virtual bool IsValid() const = 0;
- virtual size_t Send(const void* Data, size_t Size) = 0;
- virtual size_t Recv(void* Data, size_t Size) = 0;
- virtual void MarkComplete() = 0;
- virtual void Close() = 0;
+ /** Asynchronous write of exactly Size bytes. Handler called on completion or error. */
+ virtual void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) = 0;
- /** Blocking send that loops until all bytes are transferred. Returns false on error. */
- bool SendMessage(const void* Data, size_t Size);
+ /** Asynchronous read of exactly Size bytes into Data. Handler called on completion or error. */
+ virtual void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) = 0;
- /** Blocking receive that loops until all bytes are transferred. Returns false on error. */
- bool RecvMessage(void* Data, size_t Size);
+ virtual void Close() = 0;
};
-/** TCP socket transport using ASIO.
+/** Async TCP transport using ASIO.
*
- * Connects to the Horde compute endpoint specified by MachineInfo and provides
- * raw TCP send/receive. ASIO internals are hidden behind a pimpl to keep the
- * header clean.
+ * Connects to the Horde compute endpoint and provides async send/receive.
+ * The socket is created on a caller-provided io_context (shared across agents).
*/
-class TcpComputeTransport final : public ComputeTransport
+class AsyncTcpComputeTransport final : public AsyncComputeTransport
{
public:
- explicit TcpComputeTransport(const MachineInfo& Info);
- ~TcpComputeTransport() override;
-
- bool IsValid() const override;
- size_t Send(const void* Data, size_t Size) override;
- size_t Recv(void* Data, size_t Size) override;
- void MarkComplete() override;
- void Close() override;
+ /** Construct a transport on the given io_context. Does not connect yet. */
+ explicit AsyncTcpComputeTransport(asio::io_context& IoContext);
+ ~AsyncTcpComputeTransport() override;
+
+ /** Asynchronously connect to the endpoint and send the nonce. */
+ void AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler);
+
+ bool IsValid() const override;
+ void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override;
+ void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override;
+ void Close() override;
private:
LoggerRef Log() { return m_Log; }
diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp
index 505b6bde7..0b94a4397 100644
--- a/src/zenhorde/hordetransportaes.cpp
+++ b/src/zenhorde/hordetransportaes.cpp
@@ -5,9 +5,12 @@
#include <zencore/logging.h>
#include <zencore/trace.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
#include <algorithm>
#include <cstring>
-#include <random>
#if ZEN_PLATFORM_WINDOWS
# include <zencore/windows.h>
@@ -22,315 +25,410 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen::horde {
-struct AesComputeTransport::CryptoContext
-{
- uint8_t Key[KeySize] = {};
- uint8_t EncryptNonce[NonceBytes] = {};
- uint8_t DecryptNonce[NonceBytes] = {};
- bool HasErrors = false;
-
-#if !ZEN_PLATFORM_WINDOWS
- EVP_CIPHER_CTX* EncCtx = nullptr;
- EVP_CIPHER_CTX* DecCtx = nullptr;
-#endif
-
- CryptoContext(const uint8_t (&InKey)[KeySize])
- {
- memcpy(Key, InKey, KeySize);
-
- // The encrypt nonce is randomly initialized and then deterministically mutated
- // per message via UpdateNonce(). The decrypt nonce is not used — it comes from
- // the wire (each received message carries its own nonce in the header).
- std::random_device Rd;
- std::mt19937 Gen(Rd());
- std::uniform_int_distribution<int> Dist(0, 255);
- for (auto& Byte : EncryptNonce)
- {
- Byte = static_cast<uint8_t>(Dist(Gen));
- }
-
-#if !ZEN_PLATFORM_WINDOWS
- // Drain any stale OpenSSL errors
- while (ERR_get_error() != 0)
- {
- }
-
- EncCtx = EVP_CIPHER_CTX_new();
- EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
-
- DecCtx = EVP_CIPHER_CTX_new();
- EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
-#endif
- }
+namespace {
- ~CryptoContext()
- {
-#if ZEN_PLATFORM_WINDOWS
- SecureZeroMemory(Key, sizeof(Key));
- SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
- SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
-#else
- OPENSSL_cleanse(Key, sizeof(Key));
- OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
- OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));
-
- if (EncCtx)
- {
- EVP_CIPHER_CTX_free(EncCtx);
- }
- if (DecCtx)
- {
- EVP_CIPHER_CTX_free(DecCtx);
- }
-#endif
- }
+ static constexpr size_t AesNonceBytes = 12;
+ static constexpr size_t AesTagBytes = 16;
- void UpdateNonce()
+ /** AES-256-GCM crypto context. Not exposed outside this translation unit. */
+ struct AesCryptoContext
{
- uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce);
- N32[0]++;
- N32[1]--;
- N32[2] = N32[0] ^ N32[1];
- }
+ static constexpr size_t NonceBytes = AesNonceBytes;
+ static constexpr size_t TagBytes = AesTagBytes;
- // Returns total encrypted message size, or 0 on failure
- // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)]
- int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
- {
- UpdateNonce();
+ uint8_t Key[KeySize] = {};
+ uint8_t EncryptNonce[NonceBytes] = {};
+ uint8_t DecryptNonce[NonceBytes] = {};
+ uint64_t DecryptCounter = 0; ///< Sequence number of the next message to be decrypted (for diagnostics)
+ bool HasErrors = false;
- // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than
- // caching but has some overhead. For our use case (relatively large, infrequent messages)
- // this is acceptable.
#if ZEN_PLATFORM_WINDOWS
BCRYPT_ALG_HANDLE hAlg = nullptr;
BCRYPT_KEY_HANDLE hKey = nullptr;
+#else
+ EVP_CIPHER_CTX* EncCtx = nullptr;
+ EVP_CIPHER_CTX* DecCtx = nullptr;
+#endif
- BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
- BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
- BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
-
- BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
- BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
- AuthInfo.pbNonce = EncryptNonce;
- AuthInfo.cbNonce = NonceBytes;
- uint8_t Tag[TagBytes] = {};
- AuthInfo.pbTag = Tag;
- AuthInfo.cbTag = TagBytes;
-
- ULONG CipherLen = 0;
- NTSTATUS Status =
- BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0);
-
- if (!BCRYPT_SUCCESS(Status))
+ AesCryptoContext(const uint8_t (&InKey)[KeySize])
{
- HasErrors = true;
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
- return 0;
- }
+ memcpy(Key, InKey, KeySize);
- // Write header: length + nonce
- memcpy(Out, &InLength, 4);
- memcpy(Out + 4, EncryptNonce, NonceBytes);
- // Write tag after ciphertext
- memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);
+ // EncryptNonce is zero-initialized (NIST SP 800-38D §8.2.1 deterministic
+ // construction): fixed_field = 0, counter starts at 0 and is incremented
+ // before each encryption by UpdateNonce(). No RNG is used here because
+ // std::random_device is not guaranteed to be a CSPRNG (historic MinGW,
+ // some WASI targets), and the deterministic construction does not need
+ // one as long as each session uses a unique key.
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
-
- return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
+#if ZEN_PLATFORM_WINDOWS
+ NTSTATUS Status = BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ ZEN_ERROR("BCryptOpenAlgorithmProvider failed: 0x{:08x}", static_cast<uint32_t>(Status));
+ hAlg = nullptr;
+ HasErrors = true;
+ return;
+ }
+
+ Status = BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ ZEN_ERROR("BCryptSetProperty(BCRYPT_CHAIN_MODE_GCM) failed: 0x{:08x}", static_cast<uint32_t>(Status));
+ HasErrors = true;
+ return;
+ }
+
+ Status = BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ ZEN_ERROR("BCryptGenerateSymmetricKey failed: 0x{:08x}", static_cast<uint32_t>(Status));
+ hKey = nullptr;
+ HasErrors = true;
+ return;
+ }
#else
- if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
- {
- HasErrors = true;
- return 0;
+ while (ERR_get_error() != 0)
+ {
+ }
+
+ EncCtx = EVP_CIPHER_CTX_new();
+ DecCtx = EVP_CIPHER_CTX_new();
+ if (!EncCtx || !DecCtx)
+ {
+ ZEN_ERROR("EVP_CIPHER_CTX_new failed");
+ HasErrors = true;
+ return;
+ }
+
+ if (EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+ {
+ ZEN_ERROR("EVP_EncryptInit_ex(aes-256-gcm) failed: {}", ERR_get_error());
+ HasErrors = true;
+ return;
+ }
+
+ if (EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+ {
+ ZEN_ERROR("EVP_DecryptInit_ex(aes-256-gcm) failed: {}", ERR_get_error());
+ HasErrors = true;
+ return;
+ }
+#endif
}
- int32_t Offset = 0;
- // Write length
- memcpy(Out + Offset, &InLength, 4);
- Offset += 4;
- // Write nonce
- memcpy(Out + Offset, EncryptNonce, NonceBytes);
- Offset += NonceBytes;
-
- // Encrypt
- int OutLen = 0;
- if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
+ ~AesCryptoContext()
{
- HasErrors = true;
- return 0;
+#if ZEN_PLATFORM_WINDOWS
+ if (hKey)
+ {
+ BCryptDestroyKey(hKey);
+ }
+ if (hAlg)
+ {
+ BCryptCloseAlgorithmProvider(hAlg, 0);
+ }
+ SecureZeroMemory(Key, sizeof(Key));
+ SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
+ SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
+#else
+ OPENSSL_cleanse(Key, sizeof(Key));
+ OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
+ OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));
+
+ if (EncCtx)
+ {
+ EVP_CIPHER_CTX_free(EncCtx);
+ }
+ if (DecCtx)
+ {
+ EVP_CIPHER_CTX_free(DecCtx);
+ }
+#endif
}
- Offset += OutLen;
- // Finalize
- int FinalLen = 0;
- if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
+ void UpdateNonce()
{
+ // NIST SP 800-38D §8.2.1 deterministic construction:
+ // nonce = [fixed_field (4 bytes) || invocation_counter (8 bytes, big-endian)]
+ // The low 8 bytes are a strict monotonic counter starting at zero. On 2^64
+ // exhaustion the session is torn down (HasErrors) - never wrap, since a repeated
+ // (key, nonce) pair catastrophically breaks AES-GCM confidentiality and integrity.
+ for (int i = 11; i >= 4; --i)
+ {
+ if (++EncryptNonce[i] != 0)
+ {
+ return;
+ }
+ }
HasErrors = true;
- return 0;
}
- Offset += FinalLen;
- // Get tag
- if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1)
+ int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
{
- HasErrors = true;
- return 0;
- }
- Offset += TagBytes;
-
- return Offset;
-#endif
- }
+ UpdateNonce();
+ if (HasErrors)
+ {
+ return 0;
+ }
- // Decrypt a message. Returns decrypted data length, or 0 on failure.
- // Input must be [ciphertext][tag], with nonce provided separately.
- int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength)
- {
#if ZEN_PLATFORM_WINDOWS
- BCRYPT_ALG_HANDLE hAlg = nullptr;
- BCRYPT_KEY_HANDLE hKey = nullptr;
-
- BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
- BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
- BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
-
- BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
- BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
- AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce);
- AuthInfo.cbNonce = NonceBytes;
- AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength);
- AuthInfo.cbTag = TagBytes;
-
- ULONG PlainLen = 0;
- NTSTATUS Status = BCryptDecrypt(hKey,
- (PUCHAR)CipherAndTag,
- (ULONG)DataLength,
- &AuthInfo,
- nullptr,
- 0,
- (PUCHAR)Out,
- (ULONG)DataLength,
- &PlainLen,
- 0);
-
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
-
- if (!BCRYPT_SUCCESS(Status))
- {
- HasErrors = true;
- return 0;
- }
-
- return static_cast<int32_t>(PlainLen);
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = EncryptNonce;
+ AuthInfo.cbNonce = NonceBytes;
+ // Tag is output-only on encrypt; BCryptEncrypt writes TagBytes bytes into it, so skip zero-init.
+ uint8_t Tag[TagBytes];
+ AuthInfo.pbTag = Tag;
+ AuthInfo.cbTag = TagBytes;
+
+ ULONG CipherLen = 0;
+ const NTSTATUS Status = BCryptEncrypt(hKey,
+ (PUCHAR)In,
+ (ULONG)InLength,
+ &AuthInfo,
+ nullptr,
+ 0,
+ Out + 4 + NonceBytes,
+ (ULONG)InLength,
+ &CipherLen,
+ 0);
+
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ ZEN_ERROR("BCryptEncrypt failed: 0x{:08x}", static_cast<uint32_t>(Status));
+ HasErrors = true;
+ return 0;
+ }
+
+ memcpy(Out, &InLength, 4);
+ memcpy(Out + 4, EncryptNonce, NonceBytes);
+ memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);
+
+ return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
#else
- if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1)
- {
- HasErrors = true;
- return 0;
- }
-
- int OutLen = 0;
- if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1)
- {
- HasErrors = true;
- return 0;
+ // Reset per message so any stale state from a previous encrypt (e.g. partial
+ // completion after a prior error) cannot bleed into this operation. Re-bind
+ // the cipher/key; the IV is then set via the normal init call below.
+ if (EVP_CIPHER_CTX_reset(EncCtx) != 1 || EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+ {
+ ZEN_ERROR("EVP_CIPHER_CTX_reset/EncryptInit failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
+ {
+ ZEN_ERROR("EVP_EncryptInit_ex(key+iv) failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+
+ int32_t Offset = 0;
+ memcpy(Out + Offset, &InLength, 4);
+ Offset += 4;
+ memcpy(Out + Offset, EncryptNonce, NonceBytes);
+ Offset += NonceBytes;
+
+ int OutLen = 0;
+ if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
+ {
+ ZEN_ERROR("EVP_EncryptUpdate failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ Offset += OutLen;
+
+ int FinalLen = 0;
+ if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
+ {
+ ZEN_ERROR("EVP_EncryptFinal_ex failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ Offset += FinalLen;
+
+ if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1)
+ {
+ ZEN_ERROR("EVP_CTRL_GCM_GET_TAG failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ Offset += TagBytes;
+
+ return Offset;
+#endif
}
- // Set the tag for verification
- if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1)
+ int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength)
{
- HasErrors = true;
- return 0;
+#if ZEN_PLATFORM_WINDOWS
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce);
+ AuthInfo.cbNonce = NonceBytes;
+ AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength);
+ AuthInfo.cbTag = TagBytes;
+
+ ULONG PlainLen = 0;
+ const NTSTATUS Status = BCryptDecrypt(hKey,
+ (PUCHAR)CipherAndTag,
+ (ULONG)DataLength,
+ &AuthInfo,
+ nullptr,
+ 0,
+ (PUCHAR)Out,
+ (ULONG)DataLength,
+ &PlainLen,
+ 0);
+
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ // STATUS_AUTH_TAG_MISMATCH (0xC000A002) indicates GCM integrity failure -
+ // either in-flight corruption or active tampering. Log distinctly from
+ // other BCryptDecrypt failures so that tamper attempts are auditable.
+ static constexpr NTSTATUS STATUS_AUTH_TAG_MISMATCH_VAL = static_cast<NTSTATUS>(0xC000A002L);
+ if (Status == STATUS_AUTH_TAG_MISMATCH_VAL)
+ {
+ ZEN_ERROR("AES-GCM tag verification failed (seq={}): possible tampering or in-flight corruption", DecryptCounter);
+ }
+ else
+ {
+ ZEN_ERROR("BCryptDecrypt failed: 0x{:08x} (seq={})", static_cast<uint32_t>(Status), DecryptCounter);
+ }
+ HasErrors = true;
+ return 0;
+ }
+
+ ++DecryptCounter;
+ return static_cast<int32_t>(PlainLen);
+#else
+ // Same rationale as EncryptMessage: reset the context and re-bind the cipher
+ // before each decrypt to avoid stale state from a previous operation.
+ if (EVP_CIPHER_CTX_reset(DecCtx) != 1 || EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+ {
+ ZEN_ERROR("EVP_CIPHER_CTX_reset/DecryptInit failed (seq={}): {}", DecryptCounter, ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1)
+ {
+ ZEN_ERROR("EVP_DecryptInit_ex (seq={}) failed: {}", DecryptCounter, ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+
+ int OutLen = 0;
+ if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1)
+ {
+ ZEN_ERROR("EVP_DecryptUpdate failed (seq={}): {}", DecryptCounter, ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+
+ if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1)
+ {
+ ZEN_ERROR("EVP_CTRL_GCM_SET_TAG failed (seq={}): {}", DecryptCounter, ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+
+ int FinalLen = 0;
+ if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1)
+ {
+ // EVP_DecryptFinal_ex returns 0 specifically on GCM tag verification failure
+ // once the tag has been set. Log distinctly so tamper attempts are auditable.
+ ZEN_ERROR("AES-GCM tag verification failed (seq={}): possible tampering or in-flight corruption", DecryptCounter);
+ HasErrors = true;
+ return 0;
+ }
+
+ ++DecryptCounter;
+ return OutLen + FinalLen;
+#endif
}
+ };
- int FinalLen = 0;
- if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1)
- {
- HasErrors = true;
- return 0;
- }
+} // anonymous namespace
- return OutLen + FinalLen;
-#endif
- }
+struct AsyncAesComputeTransport::CryptoContext : AesCryptoContext
+{
+ using AesCryptoContext::AesCryptoContext;
};
-AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport)
+// --- AsyncAesComputeTransport ---
+
+AsyncAesComputeTransport::AsyncAesComputeTransport(const uint8_t (&Key)[KeySize],
+ std::unique_ptr<AsyncComputeTransport> InnerTransport,
+ asio::io_context& IoContext)
: m_Crypto(std::make_unique<CryptoContext>(Key))
, m_Inner(std::move(InnerTransport))
+, m_IoContext(IoContext)
{
}
-AesComputeTransport::~AesComputeTransport()
+AsyncAesComputeTransport::~AsyncAesComputeTransport()
{
Close();
}
bool
-AesComputeTransport::IsValid() const
+AsyncAesComputeTransport::IsValid() const
{
return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed;
}
-size_t
-AesComputeTransport::Send(const void* Data, size_t Size)
+void
+AsyncAesComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler)
{
- ZEN_TRACE_CPU("AesComputeTransport::Send");
-
if (!IsValid())
{
- return 0;
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- std::lock_guard<std::mutex> Lock(m_Lock);
-
const int32_t DataLength = static_cast<int32_t>(Size);
- const size_t MessageLength = 4 + NonceBytes + Size + TagBytes;
+ const size_t MessageLength = 4 + CryptoContext::NonceBytes + Size + CryptoContext::TagBytes;
- if (m_EncryptBuffer.size() < MessageLength)
- {
- m_EncryptBuffer.resize(MessageLength);
- }
+ // Encrypt directly into the per-write buffer rather than a long-lived member. Using a
+ // member (plaintext + ciphertext share that buffer during encryption on the OpenSSL
+ // path) would leave plaintext on the heap indefinitely and would also make the
+ // transport unsafe if AsyncWrite were ever invoked concurrently. Size the shared_ptr
+ // exactly to EncryptedLen afterwards.
+ auto EncBuf = std::make_shared<std::vector<uint8_t>>(MessageLength);
- const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength);
+ const int32_t EncryptedLen = m_Crypto->EncryptMessage(EncBuf->data(), Data, DataLength);
if (EncryptedLen == 0)
{
- return 0;
+ asio::post(m_IoContext,
+ [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); });
+ return;
}
- if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen)))
- {
- return 0;
- }
+ EncBuf->resize(static_cast<size_t>(EncryptedLen));
- return Size;
+ m_Inner->AsyncWrite(
+ EncBuf->data(),
+ EncBuf->size(),
+ [Handler = std::move(Handler), EncBuf, Size](const std::error_code& Ec, size_t /*BytesWritten*/) { Handler(Ec, Ec ? 0 : Size); });
}
-size_t
-AesComputeTransport::Recv(void* Data, size_t Size)
+void
+AsyncAesComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler)
{
if (!IsValid())
{
- return 0;
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes
- // than the decrypted message contains. Excess bytes are buffered in m_RemainingData
- // and returned on subsequent Recv calls without another decryption round-trip.
- ZEN_TRACE_CPU("AesComputeTransport::Recv");
-
- std::lock_guard<std::mutex> Lock(m_Lock);
+ uint8_t* Dest = static_cast<uint8_t*>(Data);
if (!m_RemainingData.empty())
{
const size_t Available = m_RemainingData.size() - m_RemainingOffset;
const size_t ToCopy = std::min(Available, Size);
- memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy);
+ memcpy(Dest, m_RemainingData.data() + m_RemainingOffset, ToCopy);
m_RemainingOffset += ToCopy;
if (m_RemainingOffset >= m_RemainingData.size())
@@ -339,82 +437,104 @@ AesComputeTransport::Recv(void* Data, size_t Size)
m_RemainingOffset = 0;
}
- return ToCopy;
- }
-
- // Receive packet header: [length(4B)][nonce(12B)]
- struct PacketHeader
- {
- int32_t DataLength = 0;
- uint8_t Nonce[NonceBytes] = {};
- } Header;
-
- if (!m_Inner->RecvMessage(&Header, sizeof(Header)))
- {
- return 0;
- }
-
- // Validate DataLength to prevent OOM from malicious/corrupt peers
- static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB
-
- if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength)
- {
- ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength);
- return 0;
- }
-
- // Receive ciphertext + tag
- const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes;
-
- if (m_EncryptBuffer.size() < MessageLength)
- {
- m_EncryptBuffer.resize(MessageLength);
- }
-
- if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength))
- {
- return 0;
- }
-
- // Decrypt
- const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size);
-
- // We need a temporary buffer for decryption if we can't decrypt directly into output
- std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength));
-
- const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength);
- if (Decrypted == 0)
- {
- return 0;
- }
-
- memcpy(Data, DecryptedBuf.data(), BytesToReturn);
+ if (ToCopy == Size)
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler), Size] { Handler(std::error_code{}, Size); });
+ return;
+ }
- // Store remaining data if we couldn't return everything
- if (static_cast<size_t>(Header.DataLength) > BytesToReturn)
- {
- m_RemainingOffset = 0;
- m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength);
+ DoRecvMessage(Dest + ToCopy, Size - ToCopy, std::move(Handler));
+ return;
}
- return BytesToReturn;
+ DoRecvMessage(Dest, Size, std::move(Handler));
}
void
-AesComputeTransport::MarkComplete()
+AsyncAesComputeTransport::DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler)
{
- if (IsValid())
- {
- m_Inner->MarkComplete();
- }
+ static constexpr size_t HeaderSize = 4 + CryptoContext::NonceBytes;
+ auto HeaderBuf = std::make_shared<std::array<uint8_t, 4 + 12>>();
+
+ m_Inner->AsyncRead(HeaderBuf->data(),
+ HeaderSize,
+ [this, Dest, Size, Handler = std::move(Handler), HeaderBuf](const std::error_code& Ec, size_t /*Bytes*/) mutable {
+ if (Ec)
+ {
+ Handler(Ec, 0);
+ return;
+ }
+
+ int32_t DataLength = 0;
+ memcpy(&DataLength, HeaderBuf->data(), 4);
+
+ static constexpr int32_t MaxDataLength = 64 * 1024 * 1024;
+ if (DataLength <= 0 || DataLength > MaxDataLength)
+ {
+ Handler(asio::error::make_error_code(asio::error::invalid_argument), 0);
+ return;
+ }
+
+ const size_t MessageLength = static_cast<size_t>(DataLength) + CryptoContext::TagBytes;
+ if (m_DecryptBuffer.size() < MessageLength)
+ {
+ m_DecryptBuffer.resize(MessageLength);
+ }
+
+ auto NonceBuf = std::make_shared<std::array<uint8_t, CryptoContext::NonceBytes>>();
+ memcpy(NonceBuf->data(), HeaderBuf->data() + 4, CryptoContext::NonceBytes);
+
+ m_Inner->AsyncRead(
+ m_DecryptBuffer.data(),
+ MessageLength,
+ [this, Dest, Size, Handler = std::move(Handler), DataLength, NonceBuf](const std::error_code& Ec,
+ size_t /*Bytes*/) mutable {
+ if (Ec)
+ {
+ Handler(Ec, 0);
+ return;
+ }
+
+ std::vector<uint8_t> PlaintextBuf(static_cast<size_t>(DataLength));
+ const int32_t Decrypted =
+ m_Crypto->DecryptMessage(PlaintextBuf.data(), NonceBuf->data(), m_DecryptBuffer.data(), DataLength);
+ if (Decrypted == 0)
+ {
+ Handler(asio::error::make_error_code(asio::error::connection_aborted), 0);
+ return;
+ }
+
+ const size_t BytesToReturn = std::min(static_cast<size_t>(Decrypted), Size);
+ memcpy(Dest, PlaintextBuf.data(), BytesToReturn);
+
+ if (static_cast<size_t>(Decrypted) > BytesToReturn)
+ {
+ m_RemainingOffset = 0;
+ m_RemainingData.assign(PlaintextBuf.begin() + BytesToReturn, PlaintextBuf.begin() + Decrypted);
+ }
+
+ if (BytesToReturn < Size)
+ {
+ DoRecvMessage(Dest + BytesToReturn, Size - BytesToReturn, std::move(Handler));
+ }
+ else
+ {
+ Handler(std::error_code{}, Size);
+ }
+ });
+ });
}
void
-AesComputeTransport::Close()
+AsyncAesComputeTransport::Close()
{
if (!m_IsClosed)
{
- if (m_Inner && m_Inner->IsValid())
+ // Always forward Close() to the inner transport if we have one. Gating on
+ // IsValid() skipped cleanup when the inner transport was partially torn down
+ // (e.g. after a read/write error marked it non-valid but left its socket open),
+ // leaking OS handles. Close implementations are expected to be idempotent.
+ if (m_Inner)
{
m_Inner->Close();
}
diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h
index efcad9835..7846073dc 100644
--- a/src/zenhorde/hordetransportaes.h
+++ b/src/zenhorde/hordetransportaes.h
@@ -6,47 +6,54 @@
#include <cstdint>
#include <memory>
-#include <mutex>
#include <vector>
+namespace asio {
+class io_context;
+}
+
namespace zen::horde {
-/** AES-256-GCM encrypted transport wrapper.
+/** Async AES-256-GCM encrypted transport wrapper.
*
- * Wraps an inner ComputeTransport, encrypting all outgoing data and decrypting
- * all incoming data using AES-256-GCM. The nonce is mutated per message using
- * the Horde nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1].
+ * Wraps an AsyncComputeTransport, encrypting outgoing and decrypting incoming
+ * data using AES-256-GCM. Outgoing nonces follow the NIST SP 800-38D §8.2.1
+ * deterministic construction: a 4-byte fixed field followed by an 8-byte
+ * big-endian monotonic counter. The session is torn down if the counter
+ * would wrap.
*
* Wire format per encrypted message:
* [plaintext length (4B little-endian)][nonce (12B)][ciphertext][GCM tag (16B)]
*
* Uses BCrypt on Windows and OpenSSL EVP on Linux/macOS (selected at compile time).
+ *
+ * Thread safety: all operations must be serialized by the caller (e.g. via a strand).
*/
-class AesComputeTransport final : public ComputeTransport
+class AsyncAesComputeTransport final : public AsyncComputeTransport
{
public:
- AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport);
- ~AesComputeTransport() override;
+ AsyncAesComputeTransport(const uint8_t (&Key)[KeySize],
+ std::unique_ptr<AsyncComputeTransport> InnerTransport,
+ asio::io_context& IoContext);
+ ~AsyncAesComputeTransport() override;
- bool IsValid() const override;
- size_t Send(const void* Data, size_t Size) override;
- size_t Recv(void* Data, size_t Size) override;
- void MarkComplete() override;
- void Close() override;
+ bool IsValid() const override;
+ void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override;
+ void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override;
+ void Close() override;
private:
- static constexpr size_t NonceBytes = 12; ///< AES-GCM nonce size
- static constexpr size_t TagBytes = 16; ///< AES-GCM authentication tag size
+ void DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler);
struct CryptoContext;
- std::unique_ptr<CryptoContext> m_Crypto;
- std::unique_ptr<ComputeTransport> m_Inner;
- std::vector<uint8_t> m_EncryptBuffer;
- std::vector<uint8_t> m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv
- size_t m_RemainingOffset = 0;
- std::mutex m_Lock;
- bool m_IsClosed = false;
+ std::unique_ptr<CryptoContext> m_Crypto;
+ std::unique_ptr<AsyncComputeTransport> m_Inner;
+ asio::io_context& m_IoContext;
+ std::vector<uint8_t> m_DecryptBuffer;
+ std::vector<uint8_t> m_RemainingData;
+ size_t m_RemainingOffset = 0;
+ bool m_IsClosed = false;
};
} // namespace zen::horde
diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h
index 201d68b83..87caec019 100644
--- a/src/zenhorde/include/zenhorde/hordeclient.h
+++ b/src/zenhorde/include/zenhorde/hordeclient.h
@@ -45,14 +45,15 @@ struct MachineInfo
uint8_t Key[KeySize] = {}; ///< 32-byte AES key (when EncryptionMode == AES)
bool IsWindows = false;
std::string LeaseId;
+ std::string Pool;
std::map<std::string, PortInfo> Ports;
/** Return the address to connect to, accounting for connection mode. */
- const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; }
+ [[nodiscard]] const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; }
/** Return the port to connect to, accounting for connection mode and port mapping. */
- uint16_t GetConnectionPort() const
+ [[nodiscard]] uint16_t GetConnectionPort() const
{
if (Mode == ConnectionMode::Relay)
{
@@ -65,7 +66,20 @@ struct MachineInfo
return Port;
}
- bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; }
+ /** Return the address and port for the Zen service endpoint, accounting for relay port mapping. */
+ [[nodiscard]] std::pair<const std::string&, uint16_t> GetZenServiceEndpoint(uint16_t DefaultPort) const
+ {
+ if (Mode == ConnectionMode::Relay)
+ {
+ if (auto It = Ports.find("ZenPort"); It != Ports.end())
+ {
+ return {ConnectionAddress, It->second.Port};
+ }
+ }
+ return {Ip, DefaultPort};
+ }
+
+ [[nodiscard]] bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; }
};
/** Result of cluster auto-resolution via the Horde API. */
@@ -83,31 +97,29 @@ struct ClusterInfo
class HordeClient
{
public:
- explicit HordeClient(const HordeConfig& Config);
+ explicit HordeClient(HordeConfig Config);
~HordeClient();
HordeClient(const HordeClient&) = delete;
HordeClient& operator=(const HordeClient&) = delete;
/** Initialize the underlying HTTP client. Must be called before other methods. */
- bool Initialize();
+ [[nodiscard]] bool Initialize();
/** Build the JSON request body for cluster resolution and machine requests.
* Encodes pool, condition, connection mode, encryption, and port requirements. */
- std::string BuildRequestBody() const;
+ [[nodiscard]] std::string BuildRequestBody() const;
/** Resolve the best cluster for the given request via POST /api/v2/compute/_cluster. */
- bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster);
+ [[nodiscard]] bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster);
/** Request a compute machine from the given cluster via POST /api/v2/compute/{clusterId}.
* On success, populates OutMachine with connection details and credentials. */
- bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine);
+ [[nodiscard]] bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine);
LoggerRef Log() { return m_Log; }
private:
- bool ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize);
-
HordeConfig m_Config;
std::unique_ptr<zen::HttpClient> m_Http;
LoggerRef m_Log;
diff --git a/src/zenhorde/include/zenhorde/hordeconfig.h b/src/zenhorde/include/zenhorde/hordeconfig.h
index dd70f9832..3a4dfb386 100644
--- a/src/zenhorde/include/zenhorde/hordeconfig.h
+++ b/src/zenhorde/include/zenhorde/hordeconfig.h
@@ -4,6 +4,10 @@
#include <zenhorde/zenhorde.h>
+#include <zenhttp/httpclient.h>
+
+#include <functional>
+#include <optional>
#include <string>
namespace zen::horde {
@@ -33,20 +37,25 @@ struct HordeConfig
static constexpr const char* ClusterDefault = "default";
static constexpr const char* ClusterAuto = "_auto";
- bool Enabled = false; ///< Whether Horde provisioning is active
- std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com")
- std::string AuthToken; ///< Authentication token for the Horde API
- std::string Pool; ///< Pool name to request machines from
- std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve
- std::string Condition; ///< Agent filter expression for machine selection
- std::string HostAddress; ///< Address that provisioned agents use to connect back to us
- std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload
- uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication
-
- int MaxCores = 2048;
- bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents
- ConnectionMode Mode = ConnectionMode::Direct;
- Encryption EncryptionMode = Encryption::None;
+ bool Enabled = false; ///< Whether Horde provisioning is active
+ std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com")
+ std::string AuthToken; ///< Authentication token for the Horde API (static fallback)
+
+ /// Optional token provider with automatic refresh (e.g. from OidcToken executable).
+ /// When set, takes priority over the static AuthToken string.
+ std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider;
+ std::string Pool; ///< Pool name to request machines from
+ std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve
+ std::string Condition; ///< Agent filter expression for machine selection
+ std::string HostAddress; ///< Address that provisioned agents use to connect back to us
+ std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload
+ uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication
+
+ int MaxCores = 2048;
+ int DrainGracePeriodSeconds = 300; ///< Grace period for draining agents before force-kill (default 5 min)
+ bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents
+ ConnectionMode Mode = ConnectionMode::Direct;
+ Encryption EncryptionMode = Encryption::None;
/** Validate the configuration. Returns false if the configuration is invalid
* (e.g. Relay mode without AES encryption). */
diff --git a/src/zenhorde/include/zenhorde/hordeprovisioner.h b/src/zenhorde/include/zenhorde/hordeprovisioner.h
index 4e2e63bbd..ea2fd7783 100644
--- a/src/zenhorde/include/zenhorde/hordeprovisioner.h
+++ b/src/zenhorde/include/zenhorde/hordeprovisioner.h
@@ -2,21 +2,32 @@
#pragma once
+#include <zenhorde/hordeclient.h>
#include <zenhorde/hordeconfig.h>
+#include <zencompute/provisionerstate.h>
#include <zencore/logbase.h>
+#include <zencore/thread.h>
#include <atomic>
#include <cstdint>
+#include <deque>
#include <filesystem>
#include <memory>
#include <mutex>
#include <string>
+#include <thread>
+#include <unordered_set>
#include <vector>
+namespace asio {
+class io_context;
+}
+
namespace zen::horde {
class HordeClient;
+class AsyncHordeAgent;
/** Snapshot of the current provisioning state, returned by HordeProvisioner::GetStats(). */
struct ProvisioningStats
@@ -35,13 +46,12 @@ struct ProvisioningStats
* binary, and executing it remotely. Each provisioned machine runs zenserver
* in compute mode, which announces itself back to the orchestrator.
*
- * Spawns one thread per agent. Each thread handles the full lifecycle:
- * HTTP request -> TCP connect -> nonce handshake -> optional AES encryption ->
- * channel setup -> binary upload -> remote execution -> poll until exit.
+ * Agent work (HTTP request, connect, upload, poll) is dispatched to a thread
+ * pool rather than spawning a dedicated thread per agent.
*
* Thread safety: SetTargetCoreCount and GetStats may be called from any thread.
*/
-class HordeProvisioner
+class HordeProvisioner : public compute::IProvisionerStateProvider
{
public:
/** Construct a provisioner.
@@ -52,38 +62,48 @@ public:
HordeProvisioner(const HordeConfig& Config,
const std::filesystem::path& BinariesPath,
const std::filesystem::path& WorkingDir,
- std::string_view OrchestratorEndpoint);
+ std::string_view OrchestratorEndpoint,
+ std::string_view CoordinatorSession = {},
+ bool CleanStart = false,
+ std::string_view TraceHost = {});
- /** Signals all agent threads to exit and joins them. */
- ~HordeProvisioner();
+ /** Signals all agents to exit and waits for completion. */
+ ~HordeProvisioner() override;
HordeProvisioner(const HordeProvisioner&) = delete;
HordeProvisioner& operator=(const HordeProvisioner&) = delete;
/** Set the target number of cores to provision.
- * Clamped to HordeConfig::MaxCores. Spawns new agent threads if the
- * estimated core count is below the target. Also joins any finished
- * agent threads. */
- void SetTargetCoreCount(uint32_t Count);
+ * Clamped to HordeConfig::MaxCores. Dispatches new agent work if the
+ * estimated core count is below the target. Also removes finished agents. */
+ void SetTargetCoreCount(uint32_t Count) override;
/** Return a snapshot of the current provisioning counters. */
ProvisioningStats GetStats() const;
- uint32_t GetActiveCoreCount() const { return m_ActiveCoreCount.load(); }
- uint32_t GetAgentCount() const;
+ // IProvisionerStateProvider
+ std::string_view GetName() const override { return "horde"; }
+ uint32_t GetTargetCoreCount() const override { return m_TargetCoreCount.load(); }
+ uint32_t GetEstimatedCoreCount() const override { return m_EstimatedCoreCount.load(); }
+ uint32_t GetActiveCoreCount() const override { return m_ActiveCoreCount.load(); }
+ uint32_t GetAgentCount() const override;
+ uint32_t GetDrainingAgentCount() const override { return m_AgentsDraining.load(); }
+ compute::AgentProvisioningStatus GetAgentStatus(std::string_view WorkerId) const override;
private:
LoggerRef Log() { return m_Log; }
- struct AgentWrapper;
-
void RequestAgent();
- void ThreadAgent(AgentWrapper& Wrapper);
+ void ProvisionAgent();
+ bool InitializeHordeClient();
HordeConfig m_Config;
std::filesystem::path m_BinariesPath;
std::filesystem::path m_WorkingDir;
std::string m_OrchestratorEndpoint;
+ std::string m_CoordinatorSession;
+ bool m_CleanStart = false;
+ std::string m_TraceHost;
std::unique_ptr<HordeClient> m_HordeClient;
@@ -91,20 +111,54 @@ private:
std::vector<std::pair<std::string, std::filesystem::path>> m_Bundles; ///< (locator, bundleDir) pairs
bool m_BundlesCreated = false;
- mutable std::mutex m_AgentsLock;
- std::vector<std::unique_ptr<AgentWrapper>> m_Agents;
-
std::atomic<uint64_t> m_LastRequestFailTime{0};
std::atomic<uint32_t> m_TargetCoreCount{0};
std::atomic<uint32_t> m_EstimatedCoreCount{0};
std::atomic<uint32_t> m_ActiveCoreCount{0};
std::atomic<uint32_t> m_AgentsActive{0};
+ std::atomic<uint32_t> m_AgentsDraining{0};
std::atomic<uint32_t> m_AgentsRequesting{0};
std::atomic<bool> m_AskForAgents{true};
+ std::atomic<uint32_t> m_PendingWorkItems{0};
+ Event m_AllWorkDone;
+ /** Manual-reset event set alongside m_AskForAgents=false so pool-thread backoff waits
+ * wake immediately on shutdown instead of polling a 100ms sleep. */
+ Event m_ShutdownEvent;
LoggerRef m_Log;
+ // Async I/O
+ std::unique_ptr<asio::io_context> m_IoContext;
+ std::vector<std::thread> m_IoThreads;
+
+ struct AsyncAgentEntry
+ {
+ std::shared_ptr<AsyncHordeAgent> Agent;
+ std::string RemoteEndpoint;
+ std::string LeaseId;
+ uint16_t CoreCount = 0;
+ bool Draining = false;
+ };
+
+ mutable std::mutex m_AsyncAgentsLock;
+ std::vector<AsyncAgentEntry> m_AsyncAgents;
+
+ /** Worker IDs of agents that completed after draining.
+ * GetAgentStatus() consumes entries when queried, but if no one queries, entries would
+ * otherwise accumulate unbounded across the lifetime of the provisioner. Cap the set
+ * at RecentlyDrainedCapacity by evicting the oldest entry (tracked in an insertion-order
+ * queue) whenever we insert past the limit. */
+ mutable std::unordered_set<std::string> m_RecentlyDrainedWorkerIds;
+ mutable std::deque<std::string> m_RecentlyDrainedOrder;
+ static constexpr size_t RecentlyDrainedCapacity = 256;
+
+ void OnAsyncAgentDone(std::shared_ptr<AsyncHordeAgent> Agent);
+ void DrainAsyncAgent(AsyncAgentEntry& Entry);
+
+ std::vector<std::string> BuildAgentArgs(const MachineInfo& Machine) const;
+
static constexpr uint32_t EstimatedCoresPerAgent = 32;
+ static constexpr int IoThreadCount = 3;
};
} // namespace zen::horde
diff --git a/src/zenhttp/asynchttpclient_test.cpp b/src/zenhttp/asynchttpclient_test.cpp
new file mode 100644
index 000000000..151863370
--- /dev/null
+++ b/src/zenhttp/asynchttpclient_test.cpp
@@ -0,0 +1,315 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/asynchttpclient.h>
+#include <zenhttp/httpserver.h>
+
+#if ZEN_WITH_TESTS
+
+# include <zencore/iobuffer.h>
+# include <zencore/logging.h>
+# include <zencore/scopeguard.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+
+# include "servers/httpasio.h"
+
+# include <atomic>
+# include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+// Reusable test service for async client tests
+
+class AsyncHttpClientTestService : public HttpService
+{
+public:
+ AsyncHttpClientTestService()
+ {
+ m_Router.RegisterRoute(
+ "hello",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello world"); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "echo",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ IoBuffer Body = HttpReq.ReadPayload();
+ HttpContentType CT = HttpReq.RequestContentType();
+ HttpReq.WriteResponse(HttpResponseCode::OK, CT, Body);
+ },
+ HttpVerb::kPost | HttpVerb::kPut);
+
+ m_Router.RegisterRoute(
+ "echo/method",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view Method = ToString(HttpReq.RequestVerb());
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Method);
+ },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead);
+
+ m_Router.RegisterRoute(
+ "nocontent",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete);
+
+ m_Router.RegisterRoute(
+ "json",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"ok\":true}");
+ },
+ HttpVerb::kGet);
+ }
+
+ virtual const char* BaseUri() const override { return "/api/async-test/"; }
+ virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); }
+
+private:
+ HttpRequestRouter m_Router;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct AsyncTestServerFixture
+{
+ AsyncHttpClientTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+ Ref<HttpServer> Server;
+ std::thread ServerThread;
+ int Port = -1;
+
+ AsyncTestServerFixture()
+ {
+ Server = CreateHttpAsioServer(AsioConfig{});
+ Port = Server->Initialize(0, TmpDir.Path());
+ ZEN_ASSERT(Port != -1);
+ Server->RegisterService(TestService);
+ ServerThread = std::thread([this]() { Server->Run(false); });
+ }
+
+ ~AsyncTestServerFixture()
+ {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ }
+
+ AsyncHttpClient MakeClient(HttpClientSettings Settings = {}) { return AsyncHttpClient(fmt::format("127.0.0.1:{}", Port), Settings); }
+
+ AsyncHttpClient MakeClient(asio::io_context& IoContext, HttpClientSettings Settings = {})
+ {
+ return AsyncHttpClient(fmt::format("127.0.0.1:{}", Port), IoContext, Settings);
+ }
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Tests
+
+TEST_SUITE_BEGIN("http.asynchttpclient");
+
+TEST_CASE("asynchttpclient.future.verbs")
+{
+ AsyncTestServerFixture Fixture;
+ AsyncHttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("GET returns 200 with expected body")
+ {
+ auto Future = Client.Get("/api/async-test/echo/method");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "GET");
+ }
+
+ SUBCASE("POST dispatches correctly")
+ {
+ auto Future = Client.Post("/api/async-test/echo/method");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "POST");
+ }
+
+ SUBCASE("PUT dispatches correctly")
+ {
+ auto Future = Client.Put("/api/async-test/echo/method");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "PUT");
+ }
+
+ SUBCASE("DELETE dispatches correctly")
+ {
+ auto Future = Client.Delete("/api/async-test/echo/method");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "DELETE");
+ }
+
+ SUBCASE("HEAD returns 200 with empty body")
+ {
+ auto Future = Client.Head("/api/async-test/echo/method");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), ""sv);
+ }
+}
+
+TEST_CASE("asynchttpclient.future.get")
+{
+ AsyncTestServerFixture Fixture;
+ AsyncHttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("simple GET with text response")
+ {
+ auto Future = Client.Get("/api/async-test/hello");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK);
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("GET returning JSON")
+ {
+ auto Future = Client.Get("/api/async-test/json");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "{\"ok\":true}");
+ }
+
+ SUBCASE("GET 204 NoContent")
+ {
+ auto Future = Client.Get("/api/async-test/nocontent");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent);
+ }
+}
+
+TEST_CASE("asynchttpclient.future.post.with.payload")
+{
+ AsyncTestServerFixture Fixture;
+ AsyncHttpClient Client = Fixture.MakeClient();
+
+ std::string_view PayloadStr = "async payload data";
+ IoBuffer Payload(IoBuffer::Clone, PayloadStr.data(), PayloadStr.size());
+ Payload.SetContentType(ZenContentType::kText);
+
+ auto Future = Client.Post("/api/async-test/echo", Payload);
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "async payload data");
+}
+
+TEST_CASE("asynchttpclient.future.put.with.payload")
+{
+ AsyncTestServerFixture Fixture;
+ AsyncHttpClient Client = Fixture.MakeClient();
+
+ std::string_view PutStr = "put payload";
+ IoBuffer Payload(IoBuffer::Clone, PutStr.data(), PutStr.size());
+ Payload.SetContentType(ZenContentType::kText);
+
+ auto Future = Client.Put("/api/async-test/echo", Payload);
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "put payload");
+}
+
+TEST_CASE("asynchttpclient.callback")
+{
+ AsyncTestServerFixture Fixture;
+ AsyncHttpClient Client = Fixture.MakeClient();
+
+ std::promise<HttpClient::Response> Promise;
+ auto Future = Promise.get_future();
+
+ Client.AsyncGet("/api/async-test/hello", [&Promise](HttpClient::Response Resp) { Promise.set_value(std::move(Resp)); });
+
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+}
+
+TEST_CASE("asynchttpclient.concurrent.requests")
+{
+ AsyncTestServerFixture Fixture;
+ AsyncHttpClient Client = Fixture.MakeClient();
+
+ // Fire multiple requests concurrently
+ auto Future1 = Client.Get("/api/async-test/hello");
+ auto Future2 = Client.Get("/api/async-test/json");
+ auto Future3 = Client.Post("/api/async-test/echo/method");
+ auto Future4 = Client.Delete("/api/async-test/echo/method");
+
+ auto Resp1 = Future1.get();
+ auto Resp2 = Future2.get();
+ auto Resp3 = Future3.get();
+ auto Resp4 = Future4.get();
+
+ CHECK(Resp1.IsSuccess());
+ CHECK_EQ(Resp1.AsText(), "hello world");
+
+ CHECK(Resp2.IsSuccess());
+ CHECK_EQ(Resp2.AsText(), "{\"ok\":true}");
+
+ CHECK(Resp3.IsSuccess());
+ CHECK_EQ(Resp3.AsText(), "POST");
+
+ CHECK(Resp4.IsSuccess());
+ CHECK_EQ(Resp4.AsText(), "DELETE");
+}
+
+TEST_CASE("asynchttpclient.external.io_context")
+{
+ AsyncTestServerFixture Fixture;
+
+ asio::io_context IoContext;
+ auto WorkGuard = asio::make_work_guard(IoContext);
+ std::thread IoThread([&IoContext]() { IoContext.run(); });
+
+ {
+ AsyncHttpClient Client = Fixture.MakeClient(IoContext);
+
+ auto Future = Client.Get("/api/async-test/hello");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ WorkGuard.reset();
+ IoThread.join();
+}
+
+TEST_CASE("asynchttpclient.connection.error")
+{
+ // Connect to a port where nothing is listening
+ AsyncHttpClient Client("127.0.0.1:1", HttpClientSettings{.ConnectTimeout = std::chrono::milliseconds(500)});
+
+ auto Future = Client.Get("/should-fail");
+ auto Resp = Future.get();
+
+ CHECK_FALSE(Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ CHECK(Resp.Error->IsConnectionError());
+}
+
+TEST_SUITE_END();
+
+void
+asynchttpclient_test_forcelink()
+{
+}
+
+} // namespace zen
+
+#endif
diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp
index 209276621..2fa22f2c2 100644
--- a/src/zenhttp/auth/authmgr.cpp
+++ b/src/zenhttp/auth/authmgr.cpp
@@ -132,7 +132,7 @@ public:
}
}
- RefPtr<OidcClient> Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}));
+ Ref<OidcClient> Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}));
if (const auto InitResult = Client->Initialize(); InitResult.Ok == false)
{
@@ -232,10 +232,10 @@ public:
private:
struct OpenIdProvider
{
- std::string Name;
- std::string Url;
- std::string ClientId;
- RefPtr<OidcClient> HttpClient;
+ std::string Name;
+ std::string Url;
+ std::string ClientId;
+ Ref<OidcClient> HttpClient;
};
struct OpenIdToken
@@ -262,7 +262,7 @@ private:
{
ZEN_TRACE_CPU("AuthMgr::RefreshOpenIdToken");
- RefPtr<OidcClient> Client = GetOpenIdProvider(ProviderName).HttpClient;
+ Ref<OidcClient> Client = GetOpenIdProvider(ProviderName).HttpClient;
if (!Client)
{
return {.Reason = fmt::format("provider '{}' is missing", ProviderName)};
diff --git a/src/zenhttp/clients/asynchttpclient.cpp b/src/zenhttp/clients/asynchttpclient.cpp
new file mode 100644
index 000000000..ea88fc783
--- /dev/null
+++ b/src/zenhttp/clients/asynchttpclient.cpp
@@ -0,0 +1,1033 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/asynchttpclient.h>
+
+#include "httpclientcurlhelpers.h"
+
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+#include <zencore/session.h>
+#include <zencore/thread.h>
+#include <zencore/trace.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+#include <asio/steady_timer.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <thread>
+#include <unordered_map>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// TransferContext: per-transfer state associated with each CURL easy handle
+
+struct TransferContext
+{
+ AsyncHttpCallback Callback;
+ std::string Body;
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ CurlWriteCallbackData WriteData;
+ CurlHeaderCallbackData HeaderData;
+ curl_slist* HeaderList = nullptr;
+
+ // For PUT/POST with payload: keep the data alive until transfer completes
+ IoBuffer PayloadBuffer;
+ CurlReadCallbackData ReadData;
+
+ TransferContext(AsyncHttpCallback&& InCallback) : Callback(std::move(InCallback))
+ {
+ WriteData.Body = &Body;
+ HeaderData.Headers = &ResponseHeaders;
+ }
+
+ ~TransferContext()
+ {
+ if (HeaderList)
+ {
+ curl_slist_free_all(HeaderList);
+ }
+ }
+
+ TransferContext(const TransferContext&) = delete;
+ TransferContext& operator=(const TransferContext&) = delete;
+};
+
+//////////////////////////////////////////////////////////////////////////
+//
+// AsyncHttpClient::Impl
+
+struct AsyncHttpClient::Impl
+{
+ Impl(std::string_view BaseUri, const HttpClientSettings& Settings)
+ : m_BaseUri(BaseUri)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_OwnedIoContext(std::make_unique<asio::io_context>())
+ , m_IoContext(*m_OwnedIoContext)
+ , m_Strand(asio::make_strand(m_IoContext))
+ , m_Timer(m_Strand)
+ {
+ Init();
+ m_WorkGuard.emplace(m_IoContext.get_executor());
+ m_IoThread = std::thread([this]() {
+ SetCurrentThreadName("async_http");
+ try
+ {
+ m_IoContext.run();
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_ERROR("AsyncHttpClient: unhandled exception in io thread: {}", Ex.what());
+ }
+ });
+ }
+
+ Impl(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings)
+ : m_BaseUri(BaseUri)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_IoContext(IoContext)
+ , m_Strand(asio::make_strand(m_IoContext))
+ , m_Timer(m_Strand)
+ {
+ Init();
+ }
+
+ ~Impl()
+ {
+ // Clean up curl state on the strand where all curl_multi operations
+ // are serialized. Use a promise to block until the cleanup handler
+ // has actually executed - essential for the external io_context case
+ // where we don't own the run loop.
+ std::promise<void> Done;
+ std::future<void> DoneFuture = Done.get_future();
+
+ asio::post(m_Strand, [this, &Done]() {
+ m_ShuttingDown = true;
+ m_Timer.cancel();
+
+ // Release all tracked sockets (don't close - curl owns the fds).
+ for (auto& [Fd, Info] : m_Sockets)
+ {
+ if (Info->Socket.is_open())
+ {
+ Info->Socket.cancel();
+ Info->Socket.release();
+ }
+ }
+ m_Sockets.clear();
+
+ for (auto& [Handle, Ctx] : m_Transfers)
+ {
+ curl_multi_remove_handle(m_Multi, Handle);
+ curl_easy_cleanup(Handle);
+ }
+ m_Transfers.clear();
+
+ for (CURL* Handle : m_HandlePool)
+ {
+ curl_easy_cleanup(Handle);
+ }
+ m_HandlePool.clear();
+
+ Done.set_value();
+ });
+
+ // For owned io_context: release work guard so run() can return after
+ // processing the cleanup handler above.
+ m_WorkGuard.reset();
+
+ if (m_IoThread.joinable())
+ {
+ m_IoThread.join();
+ }
+ else
+ {
+ // External io_context: wait for the cleanup handler to complete.
+ DoneFuture.wait();
+ }
+
+ if (m_Multi)
+ {
+ curl_multi_cleanup(m_Multi);
+ }
+ }
+
+ LoggerRef Log() { return m_Log; }
+
+ void Init()
+ {
+ m_Multi = curl_multi_init();
+ if (!m_Multi)
+ {
+ throw std::runtime_error("curl_multi_init failed");
+ }
+
+ SetupMultiCallbacks();
+
+ if (m_Settings.SessionId == Oid::Zero)
+ {
+ m_SessionId = std::string(GetSessionIdString());
+ }
+ else
+ {
+ m_SessionId = m_Settings.SessionId.ToString();
+ }
+ }
+
+ // -- Handle pool -----------------------------------------------------
+
+ CURL* AllocHandle()
+ {
+ if (!m_HandlePool.empty())
+ {
+ CURL* Handle = m_HandlePool.back();
+ m_HandlePool.pop_back();
+ curl_easy_reset(Handle);
+ return Handle;
+ }
+ CURL* Handle = curl_easy_init();
+ if (!Handle)
+ {
+ throw std::runtime_error("curl_easy_init failed");
+ }
+ return Handle;
+ }
+
+ void ReleaseHandle(CURL* Handle) { m_HandlePool.push_back(Handle); }
+
+ // -- Configure a handle with common settings -------------------------
+ // Called only from DoAsync* lambdas running on the strand.
+
+ void ConfigureHandle(CURL* Handle, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters)
+ {
+ // Build URL
+ ExtendableStringBuilder<256> Url;
+ BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str());
+
+ // Unix domain socket
+ if (!m_Settings.UnixSocketPath.empty())
+ {
+ m_UnixSocketPathUtf8 = PathToUtf8(m_Settings.UnixSocketPath);
+ curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, m_UnixSocketPathUtf8.c_str());
+ }
+
+ // Timeouts
+ if (m_Settings.ConnectTimeout.count() > 0)
+ {
+ curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(m_Settings.ConnectTimeout.count()));
+ }
+ if (m_Settings.Timeout.count() > 0)
+ {
+ curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_Settings.Timeout.count()));
+ }
+
+ // HTTP/2
+ if (m_Settings.AssumeHttp2)
+ {
+ curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE);
+ }
+
+ // SSL
+ if (m_Settings.InsecureSsl)
+ {
+ curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L);
+ curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L);
+ }
+ if (!m_Settings.CaBundlePath.empty())
+ {
+ curl_easy_setopt(Handle, CURLOPT_CAINFO, m_Settings.CaBundlePath.c_str());
+ }
+
+ // Verbose/debug
+ if (m_Settings.Verbose)
+ {
+ curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L);
+ }
+
+ // Thread safety
+ curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L);
+
+ if (m_Settings.ForbidReuseConnection)
+ {
+ curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L);
+ }
+ }
+
+ // -- Access token ----------------------------------------------------
+
+ std::optional<std::string> GetAccessToken()
+ {
+ if (!m_Settings.AccessTokenProvider.has_value())
+ {
+ return {};
+ }
+ {
+ RwLock::SharedLockScope _(m_AccessTokenLock);
+ if (!m_CachedAccessToken.NeedsRefresh())
+ {
+ return m_CachedAccessToken.GetValue();
+ }
+ }
+ RwLock::ExclusiveLockScope _(m_AccessTokenLock);
+ if (!m_CachedAccessToken.NeedsRefresh())
+ {
+ return m_CachedAccessToken.GetValue();
+ }
+ HttpClientAccessToken NewToken = m_Settings.AccessTokenProvider.value()();
+ if (!NewToken.IsValid())
+ {
+ ZEN_WARN("AsyncHttpClient: failed to refresh access token, retrying once");
+ NewToken = m_Settings.AccessTokenProvider.value()();
+ }
+ if (NewToken.IsValid())
+ {
+ m_CachedAccessToken = NewToken;
+ return m_CachedAccessToken.GetValue();
+ }
+ ZEN_WARN("AsyncHttpClient: access token provider returned invalid token");
+ return {};
+ }
+
+ // -- Submit a transfer -----------------------------------------------
+
+ void SubmitTransfer(CURL* Handle, std::unique_ptr<TransferContext> Ctx)
+ {
+ ZEN_TRACE_CPU("AsyncHttpClient::SubmitTransfer");
+ // Setup write/header callbacks
+ curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &Ctx->WriteData);
+ curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &Ctx->HeaderData);
+
+ m_Transfers[Handle] = std::move(Ctx);
+
+ CURLMcode Mc = curl_multi_add_handle(m_Multi, Handle);
+ if (Mc != CURLM_OK)
+ {
+ auto Stolen = std::move(m_Transfers[Handle]);
+ m_Transfers.erase(Handle);
+ ReleaseHandle(Handle);
+
+ HttpClient::Response ErrorResponse;
+ ErrorResponse.Error =
+ HttpClient::ErrorContext{.ErrorCode = HttpClientErrorCode::kInternalError,
+ .ErrorMessage = fmt::format("curl_multi_add_handle failed: {}", curl_multi_strerror(Mc))};
+ asio::post(m_IoContext,
+ [Cb = std::move(Stolen->Callback), Response = std::move(ErrorResponse)]() mutable { Cb(std::move(Response)); });
+ return;
+ }
+ }
+
+ // -- Socket-action integration ---------------------------------------
+ //
+ // curl_multi drives I/O via two callbacks:
+ // - SocketCallback: curl tells us which sockets to watch for read/write
+ // - TimerCallback: curl tells us when to fire a timeout
+ //
+ // On each socket event or timeout we call curl_multi_socket_action(),
+ // then drain completed transfers via curl_multi_info_read().
+
+ // Per-socket state: wraps the native fd in an ASIO socket for async_wait.
+ struct SocketInfo
+ {
+ asio::ip::tcp::socket Socket;
+ int WatchFlags = 0; // CURL_POLL_IN, CURL_POLL_OUT, CURL_POLL_INOUT
+
+ explicit SocketInfo(asio::io_context& IoContext) : Socket(IoContext) {}
+ };
+
+ // Static thunks registered with curl_multi ----------------------------
+
+ static int CurlSocketCallback(CURL* Easy, curl_socket_t Fd, int Action, void* UserPtr, void* SocketPtr)
+ {
+ ZEN_UNUSED(Easy);
+ auto* Self = static_cast<Impl*>(UserPtr);
+ Self->OnCurlSocket(Fd, Action, static_cast<SocketInfo*>(SocketPtr));
+ return 0;
+ }
+
+ static int CurlTimerCallback(CURLM* Multi, long TimeoutMs, void* UserPtr)
+ {
+ ZEN_UNUSED(Multi);
+ auto* Self = static_cast<Impl*>(UserPtr);
+ Self->OnCurlTimer(TimeoutMs);
+ return 0;
+ }
+
+ void SetupMultiCallbacks()
+ {
+ curl_multi_setopt(m_Multi, CURLMOPT_SOCKETFUNCTION, CurlSocketCallback);
+ curl_multi_setopt(m_Multi, CURLMOPT_SOCKETDATA, this);
+ curl_multi_setopt(m_Multi, CURLMOPT_TIMERFUNCTION, CurlTimerCallback);
+ curl_multi_setopt(m_Multi, CURLMOPT_TIMERDATA, this);
+ }
+
+ // Called by curl when socket watch state changes ---------------------
+
+ void OnCurlSocket(curl_socket_t Fd, int Action, SocketInfo* Info)
+ {
+ if (Action == CURL_POLL_REMOVE)
+ {
+ if (Info)
+ {
+ // Cancel pending async_wait ops before releasing the fd.
+ // curl owns the fd, so we must release() rather than close().
+ Info->Socket.cancel();
+ if (Info->Socket.is_open())
+ {
+ Info->Socket.release();
+ }
+ m_Sockets.erase(Fd);
+ }
+ return;
+ }
+
+ if (!Info)
+ {
+ // New socket - wrap the native fd in an ASIO socket.
+ auto [It, Inserted] = m_Sockets.emplace(Fd, std::make_unique<SocketInfo>(m_IoContext));
+ Info = It->second.get();
+
+ asio::error_code Ec;
+ // Determine protocol from the fd (v4 vs v6). Default to v4.
+ Info->Socket.assign(asio::ip::tcp::v4(), Fd, Ec);
+ if (Ec)
+ {
+ // Try v6 as fallback
+ Info->Socket.assign(asio::ip::tcp::v6(), Fd, Ec);
+ }
+ if (Ec)
+ {
+ ZEN_WARN("AsyncHttpClient: failed to assign socket fd {}: {}", static_cast<int>(Fd), Ec.message());
+ m_Sockets.erase(Fd);
+ return;
+ }
+
+ curl_multi_assign(m_Multi, Fd, Info);
+ }
+
+ Info->WatchFlags = Action;
+ SetSocketWatch(Fd, Info);
+ }
+
+ void SetSocketWatch(curl_socket_t Fd, SocketInfo* Info)
+ {
+ // Cancel any pending wait before issuing a new one.
+ Info->Socket.cancel();
+
+ if (Info->WatchFlags & CURL_POLL_IN)
+ {
+ Info->Socket.async_wait(asio::socket_base::wait_read, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) {
+ if (Ec || m_ShuttingDown)
+ {
+ return;
+ }
+ OnSocketReady(Fd, CURL_CSELECT_IN);
+ }));
+ }
+
+ if (Info->WatchFlags & CURL_POLL_OUT)
+ {
+ Info->Socket.async_wait(asio::socket_base::wait_write, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) {
+ if (Ec || m_ShuttingDown)
+ {
+ return;
+ }
+ OnSocketReady(Fd, CURL_CSELECT_OUT);
+ }));
+ }
+ }
+
+ void OnSocketReady(curl_socket_t Fd, int CurlAction)
+ {
+ ZEN_TRACE_CPU("AsyncHttpClient::OnSocketReady");
+ int StillRunning = 0;
+ curl_multi_socket_action(m_Multi, Fd, CurlAction, &StillRunning);
+ CheckCompleted();
+
+ // Re-arm the watch if the socket is still tracked.
+ auto It = m_Sockets.find(Fd);
+ if (It != m_Sockets.end())
+ {
+ SetSocketWatch(Fd, It->second.get());
+ }
+ }
+
+ // Called by curl when it wants a timeout ------------------------------
+
+ void OnCurlTimer(long TimeoutMs)
+ {
+ m_Timer.cancel();
+
+ if (TimeoutMs < 0)
+ {
+ // curl says "no timeout needed"
+ return;
+ }
+
+ if (TimeoutMs == 0)
+ {
+ // curl wants immediate action - run it directly on the strand.
+ asio::post(m_Strand, [this]() {
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ int StillRunning = 0;
+ curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning);
+ CheckCompleted();
+ });
+ return;
+ }
+
+ m_Timer.expires_after(std::chrono::milliseconds(TimeoutMs));
+ m_Timer.async_wait(asio::bind_executor(m_Strand, [this](const asio::error_code& Ec) {
+ if (Ec || m_ShuttingDown)
+ {
+ return;
+ }
+ ZEN_TRACE_CPU("AsyncHttpClient::OnTimeout");
+ int StillRunning = 0;
+ curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning);
+ CheckCompleted();
+ }));
+ }
+
+ // Drain completed transfers from curl_multi --------------------------
+
+ void CheckCompleted()
+ {
+ int MsgsLeft = 0;
+ CURLMsg* Msg = nullptr;
+ while ((Msg = curl_multi_info_read(m_Multi, &MsgsLeft)) != nullptr)
+ {
+ if (Msg->msg != CURLMSG_DONE)
+ {
+ continue;
+ }
+
+ CURL* Handle = Msg->easy_handle;
+ CURLcode Result = Msg->data.result;
+
+ curl_multi_remove_handle(m_Multi, Handle);
+
+ auto It = m_Transfers.find(Handle);
+ if (It == m_Transfers.end())
+ {
+ ReleaseHandle(Handle);
+ continue;
+ }
+
+ std::unique_ptr<TransferContext> Ctx = std::move(It->second);
+ m_Transfers.erase(It);
+
+ CompleteTransfer(Handle, Result, std::move(Ctx));
+ }
+ }
+
+ void CompleteTransfer(CURL* Handle, CURLcode CurlResult, std::unique_ptr<TransferContext> Ctx)
+ {
+ ZEN_TRACE_CPU("AsyncHttpClient::CompleteTransfer");
+ // Extract result info
+ long StatusCode = 0;
+ curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &StatusCode);
+
+ double Elapsed = 0;
+ curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed);
+
+ curl_off_t UpBytes = 0;
+ curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes);
+
+ curl_off_t DownBytes = 0;
+ curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes);
+
+ ReleaseHandle(Handle);
+
+ // Build response
+ HttpClient::Response Response;
+ Response.StatusCode = HttpResponseCode(StatusCode);
+ Response.UploadedBytes = static_cast<int64_t>(UpBytes);
+ Response.DownloadedBytes = static_cast<int64_t>(DownBytes);
+ Response.ElapsedSeconds = Elapsed;
+ Response.Header = BuildHeaderMap(Ctx->ResponseHeaders);
+
+ if (CurlResult != CURLE_OK)
+ {
+ const char* ErrorMsg = curl_easy_strerror(CurlResult);
+
+ if (CurlResult != CURLE_OPERATION_TIMEDOUT && CurlResult != CURLE_COULDNT_CONNECT && CurlResult != CURLE_ABORTED_BY_CALLBACK)
+ {
+ ZEN_WARN("AsyncHttpClient failure: ({}) '{}'", static_cast<int>(CurlResult), ErrorMsg);
+ }
+
+ if (!Ctx->Body.empty())
+ {
+ Response.ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size());
+ }
+
+ Response.Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(CurlResult), .ErrorMessage = std::string(ErrorMsg)};
+ }
+ else if (StatusCode == static_cast<long>(HttpResponseCode::NoContent) || Ctx->Body.empty())
+ {
+ // No payload
+ }
+ else
+ {
+ IoBuffer PayloadBuffer = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size());
+ ApplyContentTypeFromHeaders(PayloadBuffer, Ctx->ResponseHeaders);
+
+ const HttpResponseCode Code = HttpResponseCode(StatusCode);
+ if (!IsHttpSuccessCode(Code) && Code != HttpResponseCode::NotFound)
+ {
+ ZEN_WARN("AsyncHttpClient request failed: status={}, base={}", static_cast<int>(Code), m_BaseUri);
+ }
+
+ Response.ResponsePayload = std::move(PayloadBuffer);
+ }
+
+ // Dispatch the user callback off the strand so a slow callback
+ // cannot starve the curl_multi poll loop.
+ asio::post(m_IoContext, [LogRef = m_Log, Cb = std::move(Ctx->Callback), Response = std::move(Response)]() mutable {
+ try
+ {
+ Cb(std::move(Response));
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_SCOPED_LOG(LogRef);
+ ZEN_ERROR("AsyncHttpClient: unhandled exception in completion callback: {}", Ex.what());
+ }
+ });
+ }
+
+ // -- Async verb implementations --------------------------------------
+
+ void DoAsyncGet(std::string Url,
+ AsyncHttpCallback Callback,
+ HttpClient::KeyValueMap AdditionalHeader,
+ HttpClient::KeyValueMap Parameters)
+ {
+ asio::post(m_Strand,
+ [this,
+ Url = std::move(Url),
+ Callback = std::move(Callback),
+ AdditionalHeader = std::move(AdditionalHeader),
+ Parameters = std::move(Parameters)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::Get");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_HTTPGET, 1L);
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+ Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ void DoAsyncHead(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader)
+ {
+ asio::post(m_Strand,
+ [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::Head");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, {});
+ curl_easy_setopt(Handle, CURLOPT_NOBODY, 1L);
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+ Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ void DoAsyncDelete(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader)
+ {
+ asio::post(m_Strand,
+ [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::Delete");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, {});
+ curl_easy_setopt(Handle, CURLOPT_CUSTOMREQUEST, "DELETE");
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+ Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ void DoAsyncPost(std::string Url,
+ AsyncHttpCallback Callback,
+ HttpClient::KeyValueMap AdditionalHeader,
+ HttpClient::KeyValueMap Parameters)
+ {
+ asio::post(m_Strand,
+ [this,
+ Url = std::move(Url),
+ Callback = std::move(Callback),
+ AdditionalHeader = std::move(AdditionalHeader),
+ Parameters = std::move(Parameters)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::Post");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_POST, 1L);
+ curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE, 0L);
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+ Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ void DoAsyncPostWithPayload(std::string Url,
+ IoBuffer Payload,
+ ZenContentType ContentType,
+ AsyncHttpCallback Callback,
+ HttpClient::KeyValueMap AdditionalHeader)
+ {
+ asio::post(m_Strand,
+ [this,
+ Url = std::move(Url),
+ Payload = std::move(Payload),
+ ContentType,
+ Callback = std::move(Callback),
+ AdditionalHeader = std::move(AdditionalHeader)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::PostWithPayload");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, {});
+ curl_easy_setopt(Handle, CURLOPT_POST, 1L);
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+ Ctx->PayloadBuffer = std::move(Payload);
+ Ctx->HeaderList =
+ BuildHeaderList(AdditionalHeader,
+ m_SessionId,
+ GetAccessToken(),
+ {std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)))});
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ // Set up read callback for payload data
+ Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData());
+ Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize();
+ Ctx->ReadData.Offset = 0;
+
+ curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize()));
+ curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback);
+ curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ void DoAsyncPutWithPayload(std::string Url,
+ IoBuffer Payload,
+ AsyncHttpCallback Callback,
+ HttpClient::KeyValueMap AdditionalHeader,
+ HttpClient::KeyValueMap Parameters)
+ {
+ asio::post(m_Strand,
+ [this,
+ Url = std::move(Url),
+ Payload = std::move(Payload),
+ Callback = std::move(Callback),
+ AdditionalHeader = std::move(AdditionalHeader),
+ Parameters = std::move(Parameters)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::Put");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L);
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+ Ctx->PayloadBuffer = std::move(Payload);
+ Ctx->HeaderList = BuildHeaderList(
+ AdditionalHeader,
+ m_SessionId,
+ GetAccessToken(),
+ {std::make_pair("Content-Type", std::string(MapContentTypeToString(Ctx->PayloadBuffer.GetContentType())))});
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData());
+ Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize();
+ Ctx->ReadData.Offset = 0;
+
+ curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize()));
+ curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback);
+ curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ void DoAsyncPutNoPayload(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap Parameters)
+ {
+ asio::post(m_Strand, [this, Url = std::move(Url), Callback = std::move(Callback), Parameters = std::move(Parameters)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::Put");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, 0LL);
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+
+ HttpClient::KeyValueMap ContentLengthHeader{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}};
+ Ctx->HeaderList = BuildHeaderList(ContentLengthHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ // -- Members ---------------------------------------------------------
+
+ std::string m_BaseUri;
+ HttpClientSettings m_Settings;
+ LoggerRef m_Log;
+ std::string m_SessionId;
+ std::string m_UnixSocketPathUtf8;
+
+ // io_context and strand - all curl_multi operations are serialized on the
+ // strand, making this safe even when the io_context has multiple threads.
+ std::unique_ptr<asio::io_context> m_OwnedIoContext;
+ asio::io_context& m_IoContext;
+ asio::strand<asio::io_context::executor_type> m_Strand;
+ std::optional<asio::executor_work_guard<asio::io_context::executor_type>> m_WorkGuard;
+ std::thread m_IoThread;
+
+ // curl_multi and socket-action state
+ CURLM* m_Multi = nullptr;
+ std::unordered_map<CURL*, std::unique_ptr<TransferContext>> m_Transfers;
+ std::vector<CURL*> m_HandlePool;
+ std::unordered_map<curl_socket_t, std::unique_ptr<SocketInfo>> m_Sockets;
+ asio::steady_timer m_Timer;
+ bool m_ShuttingDown = false;
+
+ // Access token cache
+ RwLock m_AccessTokenLock;
+ HttpClientAccessToken m_CachedAccessToken;
+};
+
+//////////////////////////////////////////////////////////////////////////
+//
+// AsyncHttpClient public API
+
+AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(BaseUri, Settings))
+{
+}
+
+AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(BaseUri, IoContext, Settings))
+{
+}
+
+AsyncHttpClient::~AsyncHttpClient() = default;
+
+// -- Callback-based API --------------------------------------------------
+
+void
+AsyncHttpClient::AsyncGet(std::string_view Url,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters)
+{
+ m_Impl->DoAsyncGet(std::string(Url), std::move(Callback), AdditionalHeader, Parameters);
+}
+
+void
+AsyncHttpClient::AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader)
+{
+ m_Impl->DoAsyncHead(std::string(Url), std::move(Callback), AdditionalHeader);
+}
+
+void
+AsyncHttpClient::AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader)
+{
+ m_Impl->DoAsyncDelete(std::string(Url), std::move(Callback), AdditionalHeader);
+}
+
+void
+AsyncHttpClient::AsyncPost(std::string_view Url,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters)
+{
+ m_Impl->DoAsyncPost(std::string(Url), std::move(Callback), AdditionalHeader, Parameters);
+}
+
+void
+AsyncHttpClient::AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader)
+{
+ m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, Payload.GetContentType(), std::move(Callback), AdditionalHeader);
+}
+
+void
+AsyncHttpClient::AsyncPost(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader)
+{
+ m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, ContentType, std::move(Callback), AdditionalHeader);
+}
+
+void
+AsyncHttpClient::AsyncPut(std::string_view Url,
+ const IoBuffer& Payload,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters)
+{
+ m_Impl->DoAsyncPutWithPayload(std::string(Url), Payload, std::move(Callback), AdditionalHeader, Parameters);
+}
+
+void
+AsyncHttpClient::AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters)
+{
+ m_Impl->DoAsyncPutNoPayload(std::string(Url), std::move(Callback), Parameters);
+}
+
+// -- Future-based API ----------------------------------------------------
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncGet(
+ Url,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader,
+ Parameters);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncHead(
+ Url,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncDelete(
+ Url,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncPost(
+ Url,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader,
+ Parameters);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncPost(
+ Url,
+ Payload,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncPost(
+ Url,
+ Payload,
+ ContentType,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncPut(
+ Url,
+ Payload,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader,
+ Parameters);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncPut(
+ Url,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ Parameters);
+ return Future;
+}
+
+} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp
index d150b44c6..3be7337c1 100644
--- a/src/zenhttp/clients/httpclientcurl.cpp
+++ b/src/zenhttp/clients/httpclientcurl.cpp
@@ -1,6 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include "httpclientcurl.h"
+#include "httpclientcurlhelpers.h"
#include <zencore/compactbinary.h>
#include <zencore/compactbinarybuilder.h>
@@ -29,153 +30,7 @@ static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0};
//////////////////////////////////////////////////////////////////////////
-static HttpClientErrorCode
-MapCurlError(CURLcode Code)
-{
- switch (Code)
- {
- case CURLE_OK:
- return HttpClientErrorCode::kOK;
- case CURLE_COULDNT_CONNECT:
- return HttpClientErrorCode::kConnectionFailure;
- case CURLE_COULDNT_RESOLVE_HOST:
- return HttpClientErrorCode::kHostResolutionFailure;
- case CURLE_COULDNT_RESOLVE_PROXY:
- return HttpClientErrorCode::kProxyResolutionFailure;
- case CURLE_RECV_ERROR:
- return HttpClientErrorCode::kNetworkReceiveError;
- case CURLE_SEND_ERROR:
- return HttpClientErrorCode::kNetworkSendFailure;
- case CURLE_OPERATION_TIMEDOUT:
- return HttpClientErrorCode::kOperationTimedOut;
- case CURLE_SSL_CONNECT_ERROR:
- return HttpClientErrorCode::kSSLConnectError;
- case CURLE_SSL_CERTPROBLEM:
- return HttpClientErrorCode::kSSLCertificateError;
- case CURLE_PEER_FAILED_VERIFICATION:
- return HttpClientErrorCode::kSSLCACertError;
- case CURLE_SSL_CIPHER:
- case CURLE_SSL_ENGINE_NOTFOUND:
- case CURLE_SSL_ENGINE_SETFAILED:
- return HttpClientErrorCode::kGenericSSLError;
- case CURLE_ABORTED_BY_CALLBACK:
- return HttpClientErrorCode::kRequestCancelled;
- default:
- return HttpClientErrorCode::kOtherError;
- }
-}
-
-//////////////////////////////////////////////////////////////////////////
-//
-// Curl callback helpers
-
-struct WriteCallbackData
-{
- std::string* Body = nullptr;
- std::function<bool()>* CheckIfAbortFunction = nullptr;
-};
-
-static size_t
-CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData)
-{
- auto* Data = static_cast<WriteCallbackData*>(UserData);
- size_t TotalBytes = Size * Nmemb;
-
- if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
- {
- return 0; // Signal abort to curl
- }
-
- Data->Body->append(Ptr, TotalBytes);
- return TotalBytes;
-}
-
-struct HeaderCallbackData
-{
- std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
-};
-
-// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value.
-// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines).
-static std::optional<std::pair<std::string_view, std::string_view>>
-ParseHeaderLine(std::string_view Line)
-{
- while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
- {
- Line.remove_suffix(1);
- }
-
- if (Line.empty())
- {
- return std::nullopt;
- }
-
- size_t ColonPos = Line.find(':');
- if (ColonPos == std::string_view::npos)
- {
- return std::nullopt;
- }
-
- std::string_view Key = Line.substr(0, ColonPos);
- std::string_view Value = Line.substr(ColonPos + 1);
-
- while (!Key.empty() && Key.back() == ' ')
- {
- Key.remove_suffix(1);
- }
- while (!Value.empty() && Value.front() == ' ')
- {
- Value.remove_prefix(1);
- }
-
- return std::pair{Key, Value};
-}
-
-static size_t
-CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
-{
- auto* Data = static_cast<HeaderCallbackData*>(UserData);
- size_t TotalBytes = Size * Nmemb;
-
- if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)))
- {
- auto& [Key, Value] = *Header;
- Data->Headers->emplace_back(std::string(Key), std::string(Value));
- }
-
- return TotalBytes;
-}
-
-struct ReadCallbackData
-{
- const uint8_t* DataPtr = nullptr;
- size_t DataSize = 0;
- size_t Offset = 0;
- std::function<bool()>* CheckIfAbortFunction = nullptr;
-};
-
-static size_t
-CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
-{
- auto* Data = static_cast<ReadCallbackData*>(UserData);
- size_t MaxRead = Size * Nmemb;
-
- if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
- {
- return CURL_READFUNC_ABORT;
- }
-
- size_t Remaining = Data->DataSize - Data->Offset;
- size_t ToRead = std::min(MaxRead, Remaining);
-
- if (ToRead > 0)
- {
- memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead);
- Data->Offset += ToRead;
- }
-
- return ToRead;
-}
+// Curl callback helpers and shared utilities are in httpclientcurlhelpers.h
struct StreamReadCallbackData
{
@@ -233,7 +88,7 @@ CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, voi
{
ZEN_UNUSED(Handle);
LoggerRef LogRef = *static_cast<LoggerRef*>(UserPtr);
- auto Log = [&]() -> LoggerRef { return LogRef; };
+ ZEN_SCOPED_LOG(LogRef);
std::string_view DataView(Data, Size);
@@ -281,120 +136,6 @@ CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, voi
//////////////////////////////////////////////////////////////////////////
-static std::pair<std::string, std::string>
-HeaderContentType(ZenContentType ContentType)
-{
- return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)));
-}
-
-static curl_slist*
-BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
- std::string_view SessionId,
- const std::optional<std::string>& AccessToken,
- const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {})
-{
- curl_slist* Headers = nullptr;
-
- for (const auto& [Key, Value] : *AdditionalHeader)
- {
- ExtendableStringBuilder<64> HeaderLine;
- HeaderLine << Key << ": " << Value;
- Headers = curl_slist_append(Headers, HeaderLine.c_str());
- }
-
- if (!SessionId.empty())
- {
- ExtendableStringBuilder<64> SessionHeader;
- SessionHeader << "UE-Session: " << SessionId;
- Headers = curl_slist_append(Headers, SessionHeader.c_str());
- }
-
- if (AccessToken.has_value())
- {
- ExtendableStringBuilder<128> AuthHeader;
- AuthHeader << "Authorization: " << AccessToken.value();
- Headers = curl_slist_append(Headers, AuthHeader.c_str());
- }
-
- for (const auto& [Key, Value] : ExtraHeaders)
- {
- ExtendableStringBuilder<128> HeaderLine;
- HeaderLine << Key << ": " << Value;
- Headers = curl_slist_append(Headers, HeaderLine.c_str());
- }
-
- return Headers;
-}
-
-static HttpClient::KeyValueMap
-BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers)
-{
- HttpClient::KeyValueMap HeaderMap;
- for (const auto& [Key, Value] : Headers)
- {
- HeaderMap->insert_or_assign(Key, Value);
- }
- return HeaderMap;
-}
-
-// Scans response headers for Content-Type and applies it to the buffer.
-static void
-ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers)
-{
- for (const auto& [Key, Value] : Headers)
- {
- if (StrCaseCompare(Key, "Content-Type") == 0)
- {
- Buffer.SetContentType(ParseContentType(Value));
- break;
- }
- }
-}
-
-static void
-AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input)
-{
- static constexpr char HexDigits[] = "0123456789ABCDEF";
- static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~");
-
- for (char C : Input)
- {
- if (Unreserved.Contains(C))
- {
- Out.Append(C);
- }
- else
- {
- uint8_t Byte = static_cast<uint8_t>(C);
- char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]};
- Out.Append(std::string_view(Encoded, 3));
- }
- }
-}
-
-static void
-BuildUrlWithParameters(StringBuilderBase& Url,
- std::string_view BaseUrl,
- std::string_view ResourcePath,
- const HttpClient::KeyValueMap& Parameters)
-{
- Url.Append(BaseUrl);
- Url.Append(ResourcePath);
-
- if (!Parameters->empty())
- {
- char Separator = '?';
- for (const auto& [Key, Value] : *Parameters)
- {
- Url.Append(Separator);
- AppendUrlEncoded(Url, Key);
- Url.Append('=');
- AppendUrlEncoded(Url, Value);
- Separator = '&';
- }
- }
-}
-
//////////////////////////////////////////////////////////////////////////
CurlHttpClient::CurlHttpClient(std::string_view BaseUri,
@@ -440,9 +181,9 @@ CurlHttpClient::CurlResult
CurlHttpClient::Session::PerformWithResponseCallbacks()
{
std::string Body;
- WriteCallbackData WriteData{.Body = &Body,
+ CurlWriteCallbackData WriteData{.Body = &Body,
.CheckIfAbortFunction = Outer->m_CheckIfAbortFunction ? &Outer->m_CheckIfAbortFunction : nullptr};
- HeaderCallbackData HdrData{};
+ CurlHeaderCallbackData HdrData{};
std::vector<std::pair<std::string, std::string>> ResponseHeaders;
HdrData.Headers = &ResponseHeaders;
@@ -487,6 +228,13 @@ CurlHttpClient::Session::Perform()
curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes);
Result.DownloadedBytes = static_cast<int64_t>(DownBytes);
+ char* EffectiveUrl = nullptr;
+ curl_easy_getinfo(Handle, CURLINFO_EFFECTIVE_URL, &EffectiveUrl);
+ if (EffectiveUrl)
+ {
+ Result.Url = EffectiveUrl;
+ }
+
return Result;
}
@@ -553,8 +301,9 @@ CurlHttpClient::CommonResponse(std::string_view SessionId,
if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT &&
Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK)
{
- ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'",
+ ZEN_WARN("HttpClient client failure (session: {}, url: {}): ({}) '{}'",
SessionId,
+ Result.Url,
static_cast<int>(Result.ErrorCode),
Result.ErrorMessage);
}
@@ -702,6 +451,7 @@ CurlHttpClient::ShouldRetry(const CurlResult& Result)
case CURLE_RECV_ERROR:
case CURLE_SEND_ERROR:
case CURLE_OPERATION_TIMEDOUT:
+ case CURLE_PARTIAL_FILE:
return true;
default:
return false;
@@ -748,10 +498,11 @@ CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult
{
if (Result.ErrorCode != CURLE_OK)
{
- ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}",
+ ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' (Curl error: {}) Attempt {}/{}",
SessionId,
static_cast<int>(MapCurlError(Result.ErrorCode)),
Result.ErrorMessage,
+ static_cast<int>(Result.ErrorCode),
Attempt,
m_ConnectionSettings.RetryCount + 1);
}
@@ -998,9 +749,9 @@ CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValu
curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
- ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
- .DataSize = Payload.GetSize(),
- .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+ CurlReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
+ .DataSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback);
curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
@@ -1213,7 +964,7 @@ CurlHttpClient::Post(std::string_view Url,
std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value());
if (Ec)
{
- auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_SCOPED_LOG(Data->Log);
ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}",
Data->TempFolderPath->string(),
Ec.message());
@@ -1266,7 +1017,7 @@ CurlHttpClient::Post(std::string_view Url,
std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes));
if (Ec)
{
- auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_SCOPED_LOG(Data->Log);
ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}",
Data->TempFolderPath->string(),
Ec.message());
@@ -1367,9 +1118,9 @@ CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyV
return Sess.PerformWithResponseCallbacks();
}
- ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
- .DataSize = Payload.GetSize(),
- .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+ CurlReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
+ .DataSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback);
curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
@@ -1532,7 +1283,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp
std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value());
if (Ec)
{
- auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_SCOPED_LOG(Data->Log);
ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
Data->TempFolderPath->string(),
Ec.message());
@@ -1618,7 +1369,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp
std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes));
if (Ec)
{
- auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_SCOPED_LOG(Data->Log);
ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
Data->TempFolderPath->string(),
Ec.message());
diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h
index bdeb46633..ea9193e65 100644
--- a/src/zenhttp/clients/httpclientcurl.h
+++ b/src/zenhttp/clients/httpclientcurl.h
@@ -73,6 +73,7 @@ private:
int64_t DownloadedBytes = 0;
CURLcode ErrorCode = CURLE_OK;
std::string ErrorMessage;
+ std::string Url;
};
struct Session
diff --git a/src/zenhttp/clients/httpclientcurlhelpers.h b/src/zenhttp/clients/httpclientcurlhelpers.h
new file mode 100644
index 000000000..cb5f5d9a9
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcurlhelpers.h
@@ -0,0 +1,298 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+// Shared helpers for curl-based HTTP client implementations (sync and async).
+// This is an internal header, not part of the public API.
+
+#include <zencore/string.h>
+
+#include <zenhttp/httpclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <curl/curl.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Error mapping
+
+inline HttpClientErrorCode
+MapCurlError(CURLcode Code)
+{
+ switch (Code)
+ {
+ case CURLE_OK:
+ return HttpClientErrorCode::kOK;
+ case CURLE_COULDNT_CONNECT:
+ return HttpClientErrorCode::kConnectionFailure;
+ case CURLE_COULDNT_RESOLVE_HOST:
+ return HttpClientErrorCode::kHostResolutionFailure;
+ case CURLE_COULDNT_RESOLVE_PROXY:
+ return HttpClientErrorCode::kProxyResolutionFailure;
+ case CURLE_RECV_ERROR:
+ return HttpClientErrorCode::kNetworkReceiveError;
+ case CURLE_SEND_ERROR:
+ return HttpClientErrorCode::kNetworkSendFailure;
+ case CURLE_OPERATION_TIMEDOUT:
+ return HttpClientErrorCode::kOperationTimedOut;
+ case CURLE_SSL_CONNECT_ERROR:
+ return HttpClientErrorCode::kSSLConnectError;
+ case CURLE_SSL_CERTPROBLEM:
+ return HttpClientErrorCode::kSSLCertificateError;
+ case CURLE_PEER_FAILED_VERIFICATION:
+ return HttpClientErrorCode::kSSLCACertError;
+ case CURLE_SSL_CIPHER:
+ case CURLE_SSL_ENGINE_NOTFOUND:
+ case CURLE_SSL_ENGINE_SETFAILED:
+ return HttpClientErrorCode::kGenericSSLError;
+ case CURLE_ABORTED_BY_CALLBACK:
+ return HttpClientErrorCode::kRequestCancelled;
+ default:
+ return HttpClientErrorCode::kOtherError;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Curl callback data structures and callbacks
+
+struct CurlWriteCallbackData
+{
+ std::string* Body = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+inline size_t
+CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<CurlWriteCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return 0; // Signal abort to curl
+ }
+
+ Data->Body->append(Ptr, TotalBytes);
+ return TotalBytes;
+}
+
+struct CurlHeaderCallbackData
+{
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+};
+
+// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value.
+// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines).
+inline std::optional<std::pair<std::string_view, std::string_view>>
+ParseHeaderLine(std::string_view Line)
+{
+ while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
+ {
+ Line.remove_suffix(1);
+ }
+
+ if (Line.empty())
+ {
+ return std::nullopt;
+ }
+
+ size_t ColonPos = Line.find(':');
+ if (ColonPos == std::string_view::npos)
+ {
+ return std::nullopt;
+ }
+
+ std::string_view Key = Line.substr(0, ColonPos);
+ std::string_view Value = Line.substr(ColonPos + 1);
+
+ while (!Key.empty() && Key.back() == ' ')
+ {
+ Key.remove_suffix(1);
+ }
+ while (!Value.empty() && Value.front() == ' ')
+ {
+ Value.remove_prefix(1);
+ }
+
+ return std::pair{Key, Value};
+}
+
+inline size_t
+CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<CurlHeaderCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)))
+ {
+ auto& [Key, Value] = *Header;
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ }
+
+ return TotalBytes;
+}
+
+struct CurlReadCallbackData
+{
+ const uint8_t* DataPtr = nullptr;
+ size_t DataSize = 0;
+ size_t Offset = 0;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+inline size_t
+CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<CurlReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ size_t Remaining = Data->DataSize - Data->Offset;
+ size_t ToRead = std::min(MaxRead, Remaining);
+
+ if (ToRead > 0)
+ {
+ memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead);
+ Data->Offset += ToRead;
+ }
+
+ return ToRead;
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// URL and header construction
+
+inline void
+AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input)
+{
+ static constexpr char HexDigits[] = "0123456789ABCDEF";
+ static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~");
+
+ for (char C : Input)
+ {
+ if (Unreserved.Contains(C))
+ {
+ Out.Append(C);
+ }
+ else
+ {
+ uint8_t Byte = static_cast<uint8_t>(C);
+ char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]};
+ Out.Append(std::string_view(Encoded, 3));
+ }
+ }
+}
+
+inline void
+BuildUrlWithParameters(StringBuilderBase& Url,
+ std::string_view BaseUrl,
+ std::string_view ResourcePath,
+ const HttpClient::KeyValueMap& Parameters)
+{
+ Url.Append(BaseUrl);
+ Url.Append(ResourcePath);
+
+ if (!Parameters->empty())
+ {
+ char Separator = '?';
+ for (const auto& [Key, Value] : *Parameters)
+ {
+ Url.Append(Separator);
+ AppendUrlEncoded(Url, Key);
+ Url.Append('=');
+ AppendUrlEncoded(Url, Value);
+ Separator = '&';
+ }
+ }
+}
+
+inline std::pair<std::string, std::string>
+HeaderContentType(ZenContentType ContentType)
+{
+ return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)));
+}
+
+inline curl_slist*
+BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
+ std::string_view SessionId,
+ const std::optional<std::string>& AccessToken,
+ const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {})
+{
+ curl_slist* Headers = nullptr;
+
+ for (const auto& [Key, Value] : *AdditionalHeader)
+ {
+ ExtendableStringBuilder<64> HeaderLine;
+ HeaderLine << Key << ": " << Value;
+ Headers = curl_slist_append(Headers, HeaderLine.c_str());
+ }
+
+ if (!SessionId.empty())
+ {
+ ExtendableStringBuilder<64> SessionHeader;
+ SessionHeader << "UE-Session: " << SessionId;
+ Headers = curl_slist_append(Headers, SessionHeader.c_str());
+ }
+
+ if (AccessToken.has_value())
+ {
+ ExtendableStringBuilder<128> AuthHeader;
+ AuthHeader << "Authorization: " << AccessToken.value();
+ Headers = curl_slist_append(Headers, AuthHeader.c_str());
+ }
+
+ bool HasContentTypeOverride = AdditionalHeader->contains("Content-Type");
+ for (const auto& [Key, Value] : ExtraHeaders)
+ {
+ if (HasContentTypeOverride && Key == "Content-Type")
+ {
+ continue;
+ }
+ ExtendableStringBuilder<128> HeaderLine;
+ HeaderLine << Key << ": " << Value;
+ Headers = curl_slist_append(Headers, HeaderLine.c_str());
+ }
+
+ return Headers;
+}
+
+inline HttpClient::KeyValueMap
+BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers)
+{
+ HttpClient::KeyValueMap HeaderMap;
+ for (const auto& [Key, Value] : Headers)
+ {
+ HeaderMap->insert_or_assign(Key, Value);
+ }
+ return HeaderMap;
+}
+
+// Scans response headers for Content-Type and applies it to the buffer.
+inline void
+ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers)
+{
+ for (const auto& [Key, Value] : Headers)
+ {
+ if (StrCaseCompare(Key, "Content-Type") == 0)
+ {
+ Buffer.SetContentType(ParseContentType(Value));
+ break;
+ }
+ }
+}
+
+} // namespace zen
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
index ace7a3c7f..3da8a9220 100644
--- a/src/zenhttp/httpclient.cpp
+++ b/src/zenhttp/httpclient.cpp
@@ -520,7 +520,7 @@ MeasureLatency(HttpClient& Client, std::string_view Url)
ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url));
// Connection-level failures (timeout, refused, DNS) mean the endpoint is unreachable.
- // Bail out immediately — retrying will just burn the connect timeout each time.
+ // Bail out immediately - retrying will just burn the connect timeout each time.
if (MeasureResponse.Error && MeasureResponse.Error->IsConnectionError())
{
break;
diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp
index af653cbb2..deaeca2a8 100644
--- a/src/zenhttp/httpclient_test.cpp
+++ b/src/zenhttp/httpclient_test.cpp
@@ -194,7 +194,7 @@ public:
"slow",
[](HttpRouterRequest& Req) {
Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) {
- Sleep(2000);
+ Sleep(100);
Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow response");
});
},
@@ -414,6 +414,17 @@ TEST_CASE("httpclient.post")
CHECK_EQ(Resp.AsText(), "{\"key\":\"value\"}");
}
+ SUBCASE("POST with content type override via additional header")
+ {
+ const char* Payload = "test payload";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Buf, ZenContentType::kJSON, {{"Content-Type", "text/plain"}});
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "test payload");
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kText);
+ }
+
SUBCASE("POST with CbObject payload round-trip")
{
CbObjectWriter Writer;
@@ -750,7 +761,9 @@ TEST_CASE("httpclient.error-handling")
{
SUBCASE("Connection refused")
{
- HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {});
+ HttpClientSettings Settings;
+ Settings.ConnectTimeout = std::chrono::milliseconds(200);
+ HttpClient Client("127.0.0.1:19999", Settings, /*CheckIfAbortFunction*/ {});
HttpClient::Response Resp = Client.Get("/api/test/hello");
CHECK(!Resp.IsSuccess());
CHECK(Resp.Error.has_value());
@@ -760,7 +773,7 @@ TEST_CASE("httpclient.error-handling")
{
TestServerFixture Fixture;
HttpClientSettings Settings;
- Settings.Timeout = std::chrono::milliseconds(500);
+ Settings.Timeout = std::chrono::milliseconds(50);
HttpClient Client = Fixture.MakeClient(Settings);
HttpClient::Response Resp = Client.Get("/api/test/slow");
@@ -970,7 +983,9 @@ TEST_CASE("httpclient.measurelatency")
SUBCASE("Failed measurement against unreachable port")
{
- HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {});
+ HttpClientSettings Settings;
+ Settings.ConnectTimeout = std::chrono::milliseconds(200);
+ HttpClient Client("127.0.0.1:19999", Settings, /*CheckIfAbortFunction*/ {});
LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello");
CHECK(!Result.Success);
CHECK(!Result.FailureReason.empty());
@@ -1144,7 +1159,7 @@ struct FaultTcpServer
~FaultTcpServer()
{
// io_context::stop() is thread-safe; do NOT call m_Acceptor.close() from this
- // thread — ASIO I/O objects are not safe for concurrent access and the io_context
+ // thread - ASIO I/O objects are not safe for concurrent access and the io_context
// thread may be touching the acceptor in StartAccept().
m_IoContext.stop();
if (m_Thread.joinable())
@@ -1498,7 +1513,7 @@ TEST_CASE("httpclient.transport-faults-post" * doctest::skip())
std::atomic<bool> StallActive{true};
FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) {
DrainHttpRequest(Socket);
- // Stop reading body — TCP window will fill and client send will stall
+ // Stop reading body - TCP window will fill and client send will stall
while (StallActive.load())
{
std::this_thread::sleep_for(std::chrono::milliseconds(50));
@@ -1735,21 +1750,21 @@ TEST_CASE("httpclient.uri_decoding")
TestServerFixture Fixture;
HttpClient Client = Fixture.MakeClient();
- // URI without encoding — should pass through unchanged
+ // URI without encoding - should pass through unchanged
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello/world.txt");
REQUIRE(Resp.IsSuccess());
CHECK(Resp.AsText() == "uri=echo/uri/hello/world.txt\ncapture=hello/world.txt");
}
- // Percent-encoded space — server should see decoded path
+ // Percent-encoded space - server should see decoded path
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello%20world.txt");
REQUIRE(Resp.IsSuccess());
CHECK(Resp.AsText() == "uri=echo/uri/hello world.txt\ncapture=hello world.txt");
}
- // Percent-encoded slash (%2F) — should be decoded to /
+ // Percent-encoded slash (%2F) - should be decoded to /
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri/a%2Fb.txt");
REQUIRE(Resp.IsSuccess());
@@ -1763,21 +1778,21 @@ TEST_CASE("httpclient.uri_decoding")
CHECK(Resp.AsText() == "uri=echo/uri/file & name.txt\ncapture=file & name.txt");
}
- // No capture — echo/uri route returns just RelativeUri
+ // No capture - echo/uri route returns just RelativeUri
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri");
REQUIRE(Resp.IsSuccess());
CHECK(Resp.AsText() == "echo/uri");
}
- // Literal percent that is not an escape (%ZZ) — should be kept as-is
+ // Literal percent that is not an escape (%ZZ) - should be kept as-is
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri/100%25done.txt");
REQUIRE(Resp.IsSuccess());
CHECK(Resp.AsText() == "uri=echo/uri/100%done.txt\ncapture=100%done.txt");
}
- // Query params — raw values are returned as-is from GetQueryParams
+ // Query params - raw values are returned as-is from GetQueryParams
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri?key=value&name=test");
REQUIRE(Resp.IsSuccess());
@@ -1788,7 +1803,7 @@ TEST_CASE("httpclient.uri_decoding")
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri?prefix=listing%2F&mode=s3");
REQUIRE(Resp.IsSuccess());
- // GetQueryParams returns raw (still-encoded) values — callers must Decode() explicitly
+ // GetQueryParams returns raw (still-encoded) values - callers must Decode() explicitly
CHECK(Resp.AsText() == "echo/uri\nprefix=listing%2F\nmode=s3");
}
diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp
index c42841922..26a7298b3 100644
--- a/src/zenhttp/httpclientauth.cpp
+++ b/src/zenhttp/httpclientauth.cpp
@@ -50,8 +50,6 @@ namespace zen { namespace httpclientauth {
IoBuffer Payload{IoBuffer::Wrap, Body.data(), Body.size()};
- // TODO: ensure this gets the right Content-Type passed along
-
HttpClient::Response Response = Http.Post("", Payload, {{"Content-Type", "application/x-www-form-urlencoded"}});
if (!Response || Response.StatusCode != HttpResponseCode::OK)
@@ -94,7 +92,8 @@ namespace zen { namespace httpclientauth {
std::string_view CloudHost,
bool Unattended,
bool Quiet,
- bool Hidden)
+ bool Hidden,
+ bool IsHordeUrl)
{
Stopwatch Timer;
@@ -117,8 +116,9 @@ namespace zen { namespace httpclientauth {
}
});
- const std::string ProcArgs = fmt::format("{} --AuthConfigUrl {} --OutFile {} --Unattended={}",
+ const std::string ProcArgs = fmt::format("{} {} {} --OutFile {} --Unattended={}",
OidcExecutablePath,
+ IsHordeUrl ? "--HordeUrl" : "--AuthConfigUrl",
CloudHost,
AuthTokenPath,
Unattended ? "true"sv : "false"sv);
@@ -193,7 +193,7 @@ namespace zen { namespace httpclientauth {
}
else
{
- ZEN_WARN("Failed running {} to get auth token, error code {}", OidcExecutablePath, ExitCode);
+ ZEN_WARN("Failed running '{}' to get auth token, error code {}", ProcArgs, ExitCode);
}
return HttpClientAccessToken{};
}
@@ -202,9 +202,10 @@ namespace zen { namespace httpclientauth {
std::string_view CloudHost,
bool Quiet,
bool Unattended,
- bool Hidden)
+ bool Hidden,
+ bool IsHordeUrl)
{
- HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden);
+ HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl);
if (InitialToken.IsValid())
{
return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath),
@@ -212,12 +213,13 @@ namespace zen { namespace httpclientauth {
Token = InitialToken,
Quiet,
Unattended,
- Hidden]() mutable {
+ Hidden,
+ IsHordeUrl]() mutable {
if (!Token.NeedsRefresh())
{
return std::move(Token);
}
- return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden);
+ return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl);
};
}
return {};
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
index e05c9815f..03117ee6c 100644
--- a/src/zenhttp/httpserver.cpp
+++ b/src/zenhttp/httpserver.cpp
@@ -266,10 +266,10 @@ TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges)
return false;
}
- const auto Start = ParseInt<uint32_t>(Token.substr(0, Delim));
- const auto End = ParseInt<uint32_t>(Token.substr(Delim + 1));
+ const auto Start = ParseInt<uint64_t>(Token.substr(0, Delim));
+ const auto End = ParseInt<uint64_t>(Token.substr(Delim + 1));
- if (Start.has_value() && End.has_value() && End.value() > Start.value())
+ if (Start.has_value() && End.has_value() && End.value() >= Start.value())
{
Ranges.push_back({.Start = Start.value(), .End = End.value()});
}
@@ -286,6 +286,45 @@ TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges)
return Count != Ranges.size();
}
+MultipartByteRangesResult
+BuildMultipartByteRanges(const IoBuffer& Data, const HttpRanges& Ranges)
+{
+ Oid::String_t BoundaryStr;
+ Oid::NewOid().ToString(BoundaryStr);
+ std::string_view Boundary(BoundaryStr, Oid::StringLength);
+
+ const uint64_t TotalSize = Data.GetSize();
+
+ std::vector<IoBuffer> Parts;
+ Parts.reserve(Ranges.size() * 2 + 1);
+
+ for (const HttpRange& Range : Ranges)
+ {
+ uint64_t RangeEnd = (Range.End != ~uint64_t(0)) ? Range.End : TotalSize - 1;
+ if (RangeEnd >= TotalSize || Range.Start > RangeEnd)
+ {
+ return {};
+ }
+
+ uint64_t RangeSize = 1 + (RangeEnd - Range.Start);
+
+ std::string PartHeader = fmt::format("\r\n--{}\r\nContent-Type: application/octet-stream\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
+ Boundary,
+ Range.Start,
+ RangeEnd,
+ TotalSize);
+ Parts.push_back(IoBufferBuilder::MakeCloneFromMemory(PartHeader.data(), PartHeader.size()));
+
+ IoBuffer RangeData(Data, Range.Start, RangeSize);
+ Parts.push_back(RangeData);
+ }
+
+ std::string ClosingBoundary = fmt::format("\r\n--{}--", Boundary);
+ Parts.push_back(IoBufferBuilder::MakeCloneFromMemory(ClosingBoundary.data(), ClosingBoundary.size()));
+
+ return {.Parts = std::move(Parts), .ContentType = fmt::format("multipart/byteranges; boundary={}", Boundary)};
+}
+
//////////////////////////////////////////////////////////////////////////
const std::string_view
@@ -479,6 +518,18 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
return Ref<IHttpPackageHandler>();
}
+bool
+HttpService::AcceptsLocalFileReferences() const
+{
+ return false;
+}
+
+const ILocalRefPolicy*
+HttpService::GetLocalRefPolicy() const
+{
+ return nullptr;
+}
+
//////////////////////////////////////////////////////////////////////////
HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service)
@@ -552,6 +603,56 @@ HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType
}
void
+HttpServerRequest::WriteResponse(HttpContentType ContentType, const IoBuffer& Data, const HttpRanges& Ranges)
+{
+ if (Ranges.empty())
+ {
+ WriteResponse(HttpResponseCode::OK, ContentType, IoBuffer(Data));
+ return;
+ }
+
+ if (Ranges.size() == 1)
+ {
+ const HttpRange& Range = Ranges[0];
+ const uint64_t TotalSize = Data.GetSize();
+ // ~uint64_t(0) is the sentinel meaning "end of file" (suffix range).
+ const uint64_t RangeEnd = (Range.End != ~uint64_t(0)) ? Range.End : TotalSize - 1;
+
+ if (RangeEnd >= TotalSize || Range.Start > RangeEnd)
+ {
+ m_ContentRangeHeader = fmt::format("bytes */{}", TotalSize);
+ WriteResponse(HttpResponseCode::RangeNotSatisfiable);
+ return;
+ }
+
+ const uint64_t RangeSize = 1 + (RangeEnd - Range.Start);
+ IoBuffer RangeBuf(Data, Range.Start, RangeSize);
+
+ m_ContentRangeHeader = fmt::format("bytes {}-{}/{}", Range.Start, RangeEnd, TotalSize);
+ WriteResponse(HttpResponseCode::PartialContent, ContentType, std::move(RangeBuf));
+ return;
+ }
+
+ // Multi-range
+ MultipartByteRangesResult MultipartResult = BuildMultipartByteRanges(Data, Ranges);
+ if (MultipartResult.Parts.empty())
+ {
+ m_ContentRangeHeader = fmt::format("bytes */{}", Data.GetSize());
+ WriteResponse(HttpResponseCode::RangeNotSatisfiable);
+ return;
+ }
+ WriteResponse(HttpResponseCode::PartialContent, std::move(MultipartResult.ContentType), std::span<IoBuffer>(MultipartResult.Parts));
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, const std::string& CustomContentType, std::span<IoBuffer> Blobs)
+{
+ ZEN_ASSERT(ParseContentType(CustomContentType) == HttpContentType::kUnknownContentType);
+ m_ContentTypeOverride = CustomContentType;
+ WriteResponse(ResponseCode, HttpContentType::kBinary, Blobs);
+}
+
+void
HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload)
{
std::span<const SharedBuffer> Segments = Payload.GetSegments();
@@ -705,7 +806,10 @@ HttpServerRequest::ReadPayloadPackage()
{
if (IoBuffer Payload = ReadPayload())
{
- return ParsePackageMessage(std::move(Payload));
+ ParseFlags Flags =
+ (IsLocalMachineRequest() && m_Service.AcceptsLocalFileReferences()) ? ParseFlags::kAllowLocalReferences : ParseFlags::kDefault;
+ const ILocalRefPolicy* Policy = EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences) ? m_Service.GetLocalRefPolicy() : nullptr;
+ return ParsePackageMessage(std::move(Payload), {}, Flags, Policy);
}
return {};
@@ -816,7 +920,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
// Strip the separator slash left over after the service prefix is removed.
// When a service has BaseUri "/foo", the prefix length is set to len("/foo") = 4.
- // Stripping 4 chars from "/foo/bar" yields "/bar" — the path separator becomes
+ // Stripping 4 chars from "/foo/bar" yields "/bar" - the path separator becomes
// the first character of the relative URI. Remove it so patterns like "bar" or
// "{id}" match without needing to account for the leading slash.
if (!Uri.empty() && Uri.front() == '/')
@@ -1273,7 +1377,12 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP
return PackageHandlerRef->CreateTarget(Cid, Size);
};
- CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer);
+ ParseFlags PkgFlags = (Request.IsLocalMachineRequest() && Service.AcceptsLocalFileReferences())
+ ? ParseFlags::kAllowLocalReferences
+ : ParseFlags::kDefault;
+ const ILocalRefPolicy* PkgPolicy =
+ EnumHasAllFlags(PkgFlags, ParseFlags::kAllowLocalReferences) ? Service.GetLocalRefPolicy() : nullptr;
+ CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer, PkgFlags, PkgPolicy);
PackageHandlerRef->OnRequestComplete();
}
@@ -1512,7 +1621,7 @@ TEST_CASE("http.common")
},
HttpVerb::kGet);
- // Single-segment literal with leading slash — simulates real server RelativeUri
+ // Single-segment literal with leading slash - simulates real server RelativeUri
{
Reset();
TestHttpServerRequest req{Service, "/activity_counters"sv};
@@ -1532,7 +1641,7 @@ TEST_CASE("http.common")
CHECK_EQ(Captures[0], "hello"sv);
}
- // Two-segment route with leading slash — first literal segment
+ // Two-segment route with leading slash - first literal segment
{
Reset();
TestHttpServerRequest req{Service, "/prefix/world"sv};
diff --git a/src/zenhttp/include/zenhttp/asynchttpclient.h b/src/zenhttp/include/zenhttp/asynchttpclient.h
new file mode 100644
index 000000000..cb41626b9
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/asynchttpclient.h
@@ -0,0 +1,123 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenhttp.h"
+
+#include <zenhttp/httpclient.h>
+
+#include <functional>
+#include <future>
+#include <memory>
+
+namespace asio {
+class io_context;
+}
+
+namespace zen {
+
+/// Completion callback for async HTTP operations.
+using AsyncHttpCallback = std::function<void(HttpClient::Response)>;
+
+/** Asynchronous HTTP client backed by curl_multi and ASIO.
+ *
+ * Uses curl_multi_socket_action() driven by ASIO socket async_wait to process
+ * transfers without blocking the caller. All curl_multi operations are
+ * serialized on an internal strand; callers may issue requests from any
+ * thread, and the io_context may have multiple threads.
+ *
+ * Two construction modes:
+ * - Owned io_context: creates an internal thread (self-contained).
+ * - External io_context: caller runs the event loop.
+ *
+ * Completion callbacks are dispatched on the io_context (not the internal
+ * strand), so a slow callback will not block the curl poll loop. Future-
+ * based wrappers (Get, Post, ...) return a std::future<Response> for
+ * callers that prefer blocking on a result.
+ */
+class AsyncHttpClient
+{
+public:
+ using Response = HttpClient::Response;
+ using KeyValueMap = HttpClient::KeyValueMap;
+
+ /// Construct with an internally-owned io_context and thread.
+ explicit AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings = {});
+
+ /// Construct with an externally-managed io_context. The io_context must
+ /// outlive this client and must be running (via run()) on at least one thread.
+ AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings = {});
+
+ ~AsyncHttpClient();
+
+ AsyncHttpClient(const AsyncHttpClient&) = delete;
+ AsyncHttpClient& operator=(const AsyncHttpClient&) = delete;
+
+ // -- Callback-based API ----------------------------------------------
+
+ void AsyncGet(std::string_view Url,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {});
+
+ void AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {});
+
+ void AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {});
+
+ void AsyncPost(std::string_view Url,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {});
+
+ void AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {});
+
+ void AsyncPost(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader = {});
+
+ void AsyncPut(std::string_view Url,
+ const IoBuffer& Payload,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {});
+
+ void AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters = {});
+
+ // -- Future-based API ------------------------------------------------
+
+ [[nodiscard]] std::future<Response> Get(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {});
+
+ [[nodiscard]] std::future<Response> Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {});
+
+ [[nodiscard]] std::future<Response> Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {});
+
+ [[nodiscard]] std::future<Response> Post(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {});
+
+ [[nodiscard]] std::future<Response> Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {});
+
+ [[nodiscard]] std::future<Response> Post(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {});
+
+ [[nodiscard]] std::future<Response> Put(std::string_view Url,
+ const IoBuffer& Payload,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {});
+
+ [[nodiscard]] std::future<Response> Put(std::string_view Url, const KeyValueMap& Parameters = {});
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+void asynchttpclient_test_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h
index ce646ebd7..9220a50b6 100644
--- a/src/zenhttp/include/zenhttp/httpclientauth.h
+++ b/src/zenhttp/include/zenhttp/httpclientauth.h
@@ -33,7 +33,8 @@ namespace httpclientauth {
std::string_view CloudHost,
bool Quiet,
bool Unattended,
- bool Hidden);
+ bool Hidden,
+ bool IsHordeUrl = false);
} // namespace httpclientauth
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h
index f9a99f3cc..1d921600d 100644
--- a/src/zenhttp/include/zenhttp/httpcommon.h
+++ b/src/zenhttp/include/zenhttp/httpcommon.h
@@ -19,8 +19,8 @@ class StringBuilderBase;
struct HttpRange
{
- uint32_t Start = ~uint32_t(0);
- uint32_t End = ~uint32_t(0);
+ uint64_t Start = ~uint64_t(0);
+ uint64_t End = ~uint64_t(0);
};
using HttpRanges = std::vector<HttpRange>;
@@ -30,6 +30,16 @@ extern HttpContentType (*ParseContentType)(const std::string_view& ContentTypeSt
std::string_view ReasonStringForHttpResultCode(int HttpCode);
bool TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges);
+struct MultipartByteRangesResult
+{
+ std::vector<IoBuffer> Parts;
+ std::string ContentType;
+};
+
+// Build a multipart/byteranges response body from the given data and ranges.
+// Generates a unique boundary per call. Returns empty Parts if any range is out of bounds.
+MultipartByteRangesResult BuildMultipartByteRanges(const IoBuffer& Data, const HttpRanges& Ranges);
+
enum class HttpVerb : uint8_t
{
kGet = 1 << 0,
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
index 5eaed6004..955b8ed15 100644
--- a/src/zenhttp/include/zenhttp/httpserver.h
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -12,6 +12,7 @@
#include <zencore/string.h>
#include <zencore/uid.h>
#include <zenhttp/httpcommon.h>
+#include <zenhttp/localrefpolicy.h>
#include <zentelemetry/hyperloglog.h>
#include <zentelemetry/stats.h>
@@ -121,11 +122,13 @@ public:
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0;
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload);
+ void WriteResponse(HttpResponseCode ResponseCode, const std::string& CustomContentType, std::span<IoBuffer> Blobs);
void WriteResponse(HttpResponseCode ResponseCode, CbObject Data);
void WriteResponse(HttpResponseCode ResponseCode, CbArray Array);
void WriteResponse(HttpResponseCode ResponseCode, CbPackage Package);
void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString);
void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob);
+ void WriteResponse(HttpContentType ContentType, const IoBuffer& Data, const HttpRanges& Ranges);
virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) = 0;
@@ -151,6 +154,8 @@ protected:
std::string_view m_QueryString;
mutable uint32_t m_RequestId = ~uint32_t(0);
mutable Oid m_SessionId = Oid::Zero;
+ std::string m_ContentTypeOverride;
+ std::string m_ContentRangeHeader;
inline void SetIsHandled() { m_Flags |= kIsHandled; }
@@ -193,9 +198,16 @@ public:
HttpService() = default;
virtual ~HttpService() = default;
- virtual const char* BaseUri() const = 0;
- virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0;
- virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest);
+ [[nodiscard]] virtual const char* BaseUri() const = 0;
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0;
+ virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest);
+
+ /// Whether this service accepts local file references in inbound packages from local clients.
+ [[nodiscard]] virtual bool AcceptsLocalFileReferences() const;
+
+ /// Returns the local ref policy for validating file paths in inbound local references.
+ /// Returns nullptr by default, which causes file-path local refs to be rejected (fail-closed).
+ [[nodiscard]] virtual const ILocalRefPolicy* GetLocalRefPolicy() const;
// Internals
@@ -290,12 +302,12 @@ public:
std::string_view GetDefaultRedirect() const { return m_DefaultRedirect; }
- /** Track active WebSocket connections — called by server implementations on upgrade/close. */
+ /** Track active WebSocket connections - called by server implementations on upgrade/close. */
void OnWebSocketConnectionOpened() { m_ActiveWebSocketConnections.fetch_add(1, std::memory_order_relaxed); }
void OnWebSocketConnectionClosed() { m_ActiveWebSocketConnections.fetch_sub(1, std::memory_order_relaxed); }
uint64_t GetActiveWebSocketConnectionCount() const { return m_ActiveWebSocketConnections.load(std::memory_order_relaxed); }
- /** Track WebSocket frame and byte counters — called by WS connection implementations per frame. */
+ /** Track WebSocket frame and byte counters - called by WS connection implementations per frame. */
void OnWebSocketFrameReceived(uint64_t Bytes)
{
m_WsFramesReceived.fetch_add(1, std::memory_order_relaxed);
@@ -317,7 +329,7 @@ private:
int m_EffectiveHttpsPort = 0;
std::string m_ExternalHost;
metrics::Meter m_RequestMeter;
- metrics::HyperLogLog<12> m_ClientAddresses; // ~4 KiB, ~1.6% error — sufficient for client counting
+ metrics::HyperLogLog<12> m_ClientAddresses; // ~4 KiB, ~1.6% error - sufficient for client counting
metrics::HyperLogLog<12> m_ClientSessions;
std::string m_DefaultRedirect;
std::atomic<uint64_t> m_ActiveWebSocketConnections{0};
@@ -510,7 +522,8 @@ private:
bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef);
-void http_forcelink(); // internal
-void websocket_forcelink(); // internal
+void http_forcelink(); // internal
+void httpparser_forcelink(); // internal
+void websocket_forcelink(); // internal
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h
index bce771c75..51ab2e06e 100644
--- a/src/zenhttp/include/zenhttp/httpstats.h
+++ b/src/zenhttp/include/zenhttp/httpstats.h
@@ -23,11 +23,11 @@ namespace zen {
class HttpStatsService : public HttpService, public IHttpStatsService, public IWebSocketHandler
{
public:
- /// Construct without an io_context — optionally uses a dedicated push thread
+ /// Construct without an io_context - optionally uses a dedicated push thread
/// for WebSocket stats broadcasting.
explicit HttpStatsService(bool EnableWebSockets = false);
- /// Construct with an external io_context — uses an asio timer instead
+ /// Construct with an external io_context - uses an asio timer instead
/// of a dedicated thread for WebSocket stats broadcasting.
/// The caller must ensure the io_context outlives this service and that
/// its run loop is active.
@@ -43,7 +43,7 @@ public:
virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override;
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h
index 9c3b909a2..fd2f79171 100644
--- a/src/zenhttp/include/zenhttp/httpwsclient.h
+++ b/src/zenhttp/include/zenhttp/httpwsclient.h
@@ -26,7 +26,7 @@ namespace zen {
* Callback interface for WebSocket client events
*
* Separate from the server-side IWebSocketHandler because the caller
- * already owns the HttpWsClient — no Ref<WebSocketConnection> needed.
+ * already owns the HttpWsClient - no Ref<WebSocketConnection> needed.
*/
class IWsClientHandler
{
@@ -85,9 +85,9 @@ private:
/// it is treated as a plain host:port and gets the ws:// prefix.
///
/// Examples:
-/// HttpToWsUrl("http://host:8080", "/orch/ws") → "ws://host:8080/orch/ws"
-/// HttpToWsUrl("https://host", "/foo") → "wss://host/foo"
-/// HttpToWsUrl("host:8080", "/bar") → "ws://host:8080/bar"
+/// HttpToWsUrl("http://host:8080", "/orch/ws") -> "ws://host:8080/orch/ws"
+/// HttpToWsUrl("https://host", "/foo") -> "wss://host/foo"
+/// HttpToWsUrl("host:8080", "/bar") -> "ws://host:8080/bar"
std::string HttpToWsUrl(std::string_view Endpoint, std::string_view Path);
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/localrefpolicy.h b/src/zenhttp/include/zenhttp/localrefpolicy.h
new file mode 100644
index 000000000..0b37f9dc7
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/localrefpolicy.h
@@ -0,0 +1,21 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <filesystem>
+
+namespace zen {
+
+/// Policy interface for validating local file reference paths in inbound CbPackage messages.
+/// Implementations should throw std::invalid_argument if the path is not allowed.
+class ILocalRefPolicy
+{
+public:
+ virtual ~ILocalRefPolicy() = default;
+
+ /// Validate that a local file reference path is allowed.
+ /// Throws std::invalid_argument if the path escapes the allowed root.
+ virtual void ValidatePath(const std::filesystem::path& Path) const = 0;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h
index 1a5068580..66e3f6e55 100644
--- a/src/zenhttp/include/zenhttp/packageformat.h
+++ b/src/zenhttp/include/zenhttp/packageformat.h
@@ -5,6 +5,7 @@
#include <zencore/compactbinarypackage.h>
#include <zencore/iobuffer.h>
#include <zencore/iohash.h>
+#include <zenhttp/localrefpolicy.h>
#include <functional>
#include <gsl/gsl-lite.hpp>
@@ -97,11 +98,22 @@ gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions);
std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr);
CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr);
-CbPackage ParsePackageMessage(
- IoBuffer Payload,
- std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer {
+
+enum class ParseFlags
+{
+ kDefault = 0,
+ kAllowLocalReferences = (1u << 0), // Allow packages containing local file references (local clients only)
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(ParseFlags);
+
+CbPackage ParsePackageMessage(
+ IoBuffer Payload,
+ std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer {
return IoBuffer{Size};
- });
+ },
+ ParseFlags Flags = ParseFlags::kDefault,
+ const ILocalRefPolicy* Policy = nullptr);
bool IsPackageMessage(IoBuffer Payload);
bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage);
@@ -122,10 +134,11 @@ CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, void* Targe
class CbPackageReader
{
public:
- CbPackageReader();
+ CbPackageReader(ParseFlags Flags = ParseFlags::kDefault);
~CbPackageReader();
void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer);
+ void SetLocalRefPolicy(const ILocalRefPolicy* Policy);
/** Process compact binary package data stream
@@ -149,6 +162,8 @@ private:
kReadingBuffers
} m_CurrentState = State::kInitialState;
+ ParseFlags m_Flags;
+ const ILocalRefPolicy* m_LocalRefPolicy = nullptr;
std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer;
std::vector<IoBuffer> m_PayloadBuffers;
std::vector<CbAttachmentEntry> m_AttachmentEntries;
diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h
index 710579faa..2d25515d3 100644
--- a/src/zenhttp/include/zenhttp/websocket.h
+++ b/src/zenhttp/include/zenhttp/websocket.h
@@ -59,7 +59,7 @@ class IWebSocketHandler
public:
virtual ~IWebSocketHandler() = default;
- virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0;
+ virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) = 0;
virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0;
virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0;
};
diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp
index 7e6207e56..5ad5ebcc7 100644
--- a/src/zenhttp/monitoring/httpstats.cpp
+++ b/src/zenhttp/monitoring/httpstats.cpp
@@ -196,8 +196,9 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request)
//
void
-HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
+ ZEN_UNUSED(RelativeUri);
ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen");
ZEN_INFO("Stats WebSocket client connected");
diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp
index 9c62c1f2d..267ce386c 100644
--- a/src/zenhttp/packageformat.cpp
+++ b/src/zenhttp/packageformat.cpp
@@ -36,6 +36,71 @@ const std::string_view HandlePrefix(":?#:");
typedef eastl::fixed_vector<IoBuffer, 16> IoBufferVec_t;
+/// Enforce local-ref path policy. Handle-based refs bypass the policy since they use OS handle security.
+/// If no policy is set, file-path local refs are rejected (fail-closed).
+static void
+ApplyLocalRefPolicy(const ILocalRefPolicy* Policy, const std::filesystem::path& Path)
+{
+ if (Policy)
+ {
+ Policy->ValidatePath(Path);
+ }
+ else
+ {
+ throw std::invalid_argument("local file reference rejected: no validation policy");
+ }
+}
+
+// Validates the CbPackageHeader magic and attachment count. Returns the total
+// chunk count (AttachmentCount + 1, including the implicit root object).
+static uint32_t
+ValidatePackageHeader(const CbPackageHeader& Hdr)
+{
+ if (Hdr.HeaderMagic != kCbPkgMagic)
+ {
+ throw std::invalid_argument(
+ fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr.HeaderMagic));
+ }
+ // ChunkCount is AttachmentCount + 1 (the root object is implicit). Guard against
+ // UINT32_MAX wrapping to 0, which would bypass subsequent size checks.
+ if (Hdr.AttachmentCount == UINT32_MAX)
+ {
+ throw std::invalid_argument("invalid CbPackage, attachment count overflow");
+ }
+ return Hdr.AttachmentCount + 1;
+}
+
+struct ValidatedLocalRef
+{
+ bool Valid = false;
+ const CbAttachmentReferenceHeader* Header = nullptr;
+ std::string_view Path;
+ std::string Error;
+};
+
+// Validates that the attachment buffer contains a well-formed local reference
+// header and path. On failure, Valid is false and Error contains the reason.
+static ValidatedLocalRef
+ValidateLocalRef(const IoBuffer& AttachmentBuffer)
+{
+ if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader))
+ {
+ return {.Error = fmt::format("local ref attachment too small for header (size {})", AttachmentBuffer.Size())};
+ }
+
+ const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
+
+ if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)
+ {
+ return {.Error = fmt::format("local ref attachment too small for path (need {}, have {})",
+ sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength,
+ AttachmentBuffer.Size())};
+ }
+
+ const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1);
+ return {.Valid = true, .Header = AttachRefHdr, .Path = std::string_view(PathPointer, AttachRefHdr->AbsolutePathLength)};
+}
+
IoBufferVec_t FormatPackageMessageInternal(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle);
std::vector<IoBuffer>
@@ -361,7 +426,10 @@ IsPackageMessage(IoBuffer Payload)
}
CbPackage
-ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer)
+ParsePackageMessage(IoBuffer Payload,
+ std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer,
+ ParseFlags Flags,
+ const ILocalRefPolicy* Policy)
{
ZEN_TRACE_CPU("ParsePackageMessage");
@@ -372,17 +440,13 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint
BinaryReader Reader(Payload);
- const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData());
- if (Hdr->HeaderMagic != kCbPkgMagic)
- {
- throw std::invalid_argument(
- fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr->HeaderMagic));
- }
+ const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData());
+ const uint32_t ChunkCount = ValidatePackageHeader(*Hdr);
Reader.Skip(sizeof(CbPackageHeader));
- const uint32_t ChunkCount = Hdr->AttachmentCount + 1;
-
- if (Reader.Remaining() < sizeof(CbAttachmentEntry) * ChunkCount)
+ // Widen to uint64_t so the multiplication cannot wrap on 32-bit.
+ const uint64_t AttachmentTableSize = uint64_t(sizeof(CbAttachmentEntry)) * ChunkCount;
+ if (Reader.Remaining() < AttachmentTableSize)
{
throw std::invalid_argument(fmt::format("invalid CbPackage, missing attachment entry data (need {} bytes, have {} bytes)",
sizeof(CbAttachmentEntry) * ChunkCount,
@@ -417,15 +481,22 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint
if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
{
- // Marshal local reference - a "pointer" to the chunk backing file
-
- ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
+ if (!EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences))
+ {
+ throw std::invalid_argument(
+ fmt::format("package contains local reference (attachment #{}) but local references are not allowed", i));
+ }
- const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
- const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1);
+ // Marshal local reference - a "pointer" to the chunk backing file
- ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
- std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength);
+ ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer);
+ if (!LocalRef.Valid)
+ {
+ MalformedAttachments.push_back(std::make_pair(i, fmt::format("{} for {}", LocalRef.Error, Entry.AttachmentHash)));
+ continue;
+ }
+ const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header;
+ std::string_view PathView = LocalRef.Path;
IoBuffer FullFileBuffer;
@@ -461,13 +532,29 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint
}
else
{
+ ApplyLocalRefPolicy(Policy, Path);
FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second;
}
}
if (FullFileBuffer)
{
- IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize()
+ // Guard against offset+size overflow or exceeding the file bounds.
+ const uint64_t FileSize = FullFileBuffer.GetSize();
+ if (AttachRefHdr->PayloadByteOffset > FileSize ||
+ AttachRefHdr->PayloadByteSize > FileSize - AttachRefHdr->PayloadByteOffset)
+ {
+ MalformedAttachments.push_back(
+ std::make_pair(i,
+ fmt::format("Local ref offset/size out of bounds (offset {}, size {}, file size {}) for {}",
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize,
+ FileSize,
+ Entry.AttachmentHash)));
+ continue;
+ }
+
+ IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FileSize
? FullFileBuffer
: IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
@@ -630,7 +717,9 @@ ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPa
return OutPackage.TryLoad(Response);
}
-CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; })
+CbPackageReader::CbPackageReader(ParseFlags Flags)
+: m_Flags(Flags)
+, m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; })
{
}
@@ -644,6 +733,12 @@ CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Ci
m_CreateBuffer = CreateBuffer;
}
+void
+CbPackageReader::SetLocalRefPolicy(const ILocalRefPolicy* Policy)
+{
+ m_LocalRefPolicy = Policy;
+}
+
uint64_t
CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes)
{
@@ -657,12 +752,14 @@ CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes)
return sizeof m_PackageHeader;
case State::kReadingHeader:
- ZEN_ASSERT(DataBytes == sizeof m_PackageHeader);
- memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader);
- ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic);
- m_CurrentState = State::kReadingAttachmentEntries;
- m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1);
- return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry);
+ {
+ ZEN_ASSERT(DataBytes == sizeof m_PackageHeader);
+ memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader);
+ const uint32_t ChunkCount = ValidatePackageHeader(m_PackageHeader);
+ m_CurrentState = State::kReadingAttachmentEntries;
+ m_AttachmentEntries.resize(ChunkCount);
+ return uint64_t(ChunkCount) * sizeof(CbAttachmentEntry);
+ }
case State::kReadingAttachmentEntries:
ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry)));
@@ -691,16 +788,19 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer)
{
// Marshal local reference - a "pointer" to the chunk backing file
- ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
-
- const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
- const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1);
-
- ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
+ ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer);
+ if (!LocalRef.Valid)
+ {
+ throw std::invalid_argument(LocalRef.Error);
+ }
- std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength};
+ const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header;
+ std::filesystem::path Path(Utf8ToWide(LocalRef.Path));
- std::filesystem::path Path{PathView};
+ if (!LocalRef.Path.starts_with(HandlePrefix))
+ {
+ ApplyLocalRefPolicy(m_LocalRefPolicy, Path);
+ }
IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
@@ -714,6 +814,17 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer)
AttachRefHdr->PayloadByteSize));
}
+ // MakeFromFile silently clamps offset+size to the file size. Detect this
+ // to avoid returning a short buffer that could cause subtle downstream issues.
+ if (ChunkReference.GetSize() != AttachRefHdr->PayloadByteSize)
+ {
+ throw std::invalid_argument(fmt::format("local ref offset/size out of bounds for '{}' (requested offset {}, size {}, got size {})",
+ PathToUtf8(Path),
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize,
+ ChunkReference.GetSize()));
+ }
+
return ChunkReference;
};
@@ -732,6 +843,13 @@ CbPackageReader::Finalize()
{
IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex];
+ if ((Entry.Flags & CbAttachmentEntry::kIsLocalRef) && !EnumHasAllFlags(m_Flags, ParseFlags::kAllowLocalReferences))
+ {
+ throw std::invalid_argument(
+ fmt::format("package contains local reference (attachment #{}) but local references are not allowed",
+ CurrentAttachmentIndex));
+ }
+
if (CurrentAttachmentIndex == 0)
{
// Root object
@@ -815,6 +933,13 @@ CbPackageReader::Finalize()
TEST_SUITE_BEGIN("http.packageformat");
+/// Permissive policy that allows any path, for use in tests that exercise local ref
+/// functionality but are not testing path validation.
+struct PermissiveLocalRefPolicy : public ILocalRefPolicy
+{
+ void ValidatePath(const std::filesystem::path&) const override {}
+};
+
TEST_CASE("CbPackage.Serialization")
{
// Make a test package
@@ -922,6 +1047,169 @@ TEST_CASE("CbPackage.LocalRef")
RemainingBytes -= ByteCount;
};
+ PermissiveLocalRefPolicy AllowAllPolicy;
+ CbPackageReader Reader(ParseFlags::kAllowLocalReferences);
+ Reader.SetLocalRefPolicy(&AllowAllPolicy);
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
+ NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes);
+ auto Buffers = Reader.GetPayloadBuffers();
+
+ for (auto& PayloadBuffer : Buffers)
+ {
+ CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
+ }
+
+ Reader.Finalize();
+}
+
+TEST_CASE("CbPackage.Validation.TruncatedHeader")
+{
+ // Payload too small for a CbPackageHeader
+ uint8_t Bytes[] = {0xcc, 0xaa, 0x77, 0xaa};
+ IoBuffer Payload(IoBuffer::Wrap, Bytes, sizeof(Bytes));
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.BadMagic")
+{
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = 0xDEADBEEF;
+ Hdr.AttachmentCount = 0;
+ IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr));
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.AttachmentCountOverflow")
+{
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = kCbPkgMagic;
+ Hdr.AttachmentCount = UINT32_MAX;
+ IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr));
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.TruncatedAttachmentTable")
+{
+ // Valid header but not enough data for the attachment entries
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = kCbPkgMagic;
+ Hdr.AttachmentCount = 10;
+ IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr));
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.TruncatedAttachmentData")
+{
+ // Valid header + one attachment entry claiming more data than available
+ std::vector<uint8_t> Data(sizeof(CbPackageHeader) + sizeof(CbAttachmentEntry));
+
+ CbPackageHeader* Hdr = reinterpret_cast<CbPackageHeader*>(Data.data());
+ Hdr->HeaderMagic = kCbPkgMagic;
+ Hdr->AttachmentCount = 0; // ChunkCount = 1 (root object)
+
+ CbAttachmentEntry* Entry = reinterpret_cast<CbAttachmentEntry*>(Data.data() + sizeof(CbPackageHeader));
+ Entry->PayloadSize = 9999; // way more than available
+ Entry->Flags = CbAttachmentEntry::kIsObject;
+ Entry->AttachmentHash = IoHash();
+
+ IoBuffer Payload(IoBuffer::Wrap, Data.data(), Data.size());
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.LocalRefRejectedByDefault")
+{
+ // Build a valid package with local refs backed by compressed-format files,
+ // then verify it's rejected with default ParseFlags and accepted when allowed.
+ ScopedTemporaryDirectory TempDir;
+ auto Path1 = TempDir.Path() / "abcd";
+ auto Path2 = TempDir.Path() / "efgh";
+
+ // Compress data and write to disk, then create file-backed compressed attachments.
+ // The files must contain compressed-format data because ParsePackageMessage expects it
+ // when resolving local refs.
+ CompressedBuffer Comp1 =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("abcd")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ CompressedBuffer Comp2 =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("efgh")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+
+ IoHash Hash1 = Comp1.DecodeRawHash();
+ IoHash Hash2 = Comp2.DecodeRawHash();
+
+ {
+ IoBuffer Buf1 = Comp1.GetCompressed().Flatten().AsIoBuffer();
+ IoBuffer Buf2 = Comp2.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(Path1, Buf1);
+ WriteFile(Path2, Buf2);
+ }
+
+ // Create attachments from file-backed buffers so FormatPackageMessage uses local refs
+ CbAttachment Attach1{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path1)), Hash1};
+ CbAttachment Attach2{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path2)), Hash2};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("abcd", Attach1);
+ Cbo.AddAttachment("efgh", Attach2);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach1);
+ Pkg.AddAttachment(Attach2);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ // Default flags should reject local refs
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+
+ // With kAllowLocalReferences + a permissive policy, the local-ref gate is passed (the full round-trip
+ // for local refs through ParsePackageMessage is covered by CbPackage.LocalRef via CbPackageReader)
+ PermissiveLocalRefPolicy AllowAllPolicy;
+ CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &AllowAllPolicy);
+ CHECK(Result.GetObject());
+ CHECK(Result.GetAttachments().size() == 2);
+}
+
+TEST_CASE("CbPackage.Validation.LocalRefRejectedByReader")
+{
+ // Same test but via CbPackageReader
+ ScopedTemporaryDirectory TempDir;
+ auto FilePath = TempDir.Path() / "testdata";
+
+ {
+ IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("testdata"));
+ WriteFile(FilePath, Buf);
+ }
+
+ IoBuffer FileBuffer = IoBufferBuilder::MakeFromFile(FilePath);
+ CbAttachment Attach{SharedBuffer(FileBuffer)};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten();
+ const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData());
+ uint64_t RemainingBytes = Buffer.GetSize();
+
+ auto ConsumeBytes = [&](uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ void* ReturnPtr = (void*)CursorPtr;
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ return ReturnPtr;
+ };
+
+ auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ memcpy(TargetBuffer, CursorPtr, ByteCount);
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ };
+
+ // Default flags should reject
CbPackageReader Reader;
uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
@@ -933,7 +1221,199 @@ TEST_CASE("CbPackage.LocalRef")
CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
}
- Reader.Finalize();
+ CHECK_THROWS_AS(Reader.Finalize(), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.BadMagicViaReader")
+{
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = 0xBADCAFE;
+ Hdr.AttachmentCount = 0;
+
+ CbPackageReader Reader;
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.AttachmentCountOverflowViaReader")
+{
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = kCbPkgMagic;
+ Hdr.AttachmentCount = UINT32_MAX;
+
+ CbPackageReader Reader;
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.LocalRefPolicy.PathOutsideRoot")
+{
+ // A file outside the allowed root should be rejected by the policy
+ ScopedTemporaryDirectory AllowedRoot;
+ ScopedTemporaryDirectory OutsideDir;
+
+ auto OutsidePath = OutsideDir.Path() / "secret.dat";
+ {
+ IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("secret"));
+ WriteFile(OutsidePath, Buf);
+ }
+
+ // Create file-backed compressed attachment from outside root
+ CompressedBuffer Comp =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("secret")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ IoHash Hash = Comp.DecodeRawHash();
+ {
+ IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(OutsidePath, Buf);
+ }
+
+ CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ // Policy rooted at AllowedRoot should reject the file in OutsideDir
+ struct TestPolicy : public ILocalRefPolicy
+ {
+ std::string Root;
+ void ValidatePath(const std::filesystem::path& Path) const override
+ {
+ std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string();
+ if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0)
+ {
+ throw std::invalid_argument("path outside root");
+ }
+ }
+ } Policy;
+ Policy.Root = std::filesystem::weakly_canonical(AllowedRoot.Path()).string();
+
+ CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.LocalRefPolicy.PathInsideRoot")
+{
+ // A file inside the allowed root should be accepted by the policy
+ ScopedTemporaryDirectory TempRoot;
+
+ auto FilePath = TempRoot.Path() / "data.dat";
+
+ CompressedBuffer Comp =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("hello")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ IoHash Hash = Comp.DecodeRawHash();
+ {
+ IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(FilePath, Buf);
+ }
+
+ CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ struct TestPolicy : public ILocalRefPolicy
+ {
+ std::string Root;
+ void ValidatePath(const std::filesystem::path& Path) const override
+ {
+ std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string();
+ if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0)
+ {
+ throw std::invalid_argument("path outside root");
+ }
+ }
+ } Policy;
+ Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string();
+
+ CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy);
+ CHECK(Result.GetObject());
+ CHECK(Result.GetAttachments().size() == 1);
+}
+
+TEST_CASE("CbPackage.LocalRefPolicy.PathTraversal")
+{
+ // A file path containing ".." that resolves outside root should be rejected
+ ScopedTemporaryDirectory TempRoot;
+ ScopedTemporaryDirectory OutsideDir;
+
+ auto OutsidePath = OutsideDir.Path() / "evil.dat";
+
+ CompressedBuffer Comp =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("evil")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ IoHash Hash = Comp.DecodeRawHash();
+ {
+ IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(OutsidePath, Buf);
+ }
+
+ CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ struct TestPolicy : public ILocalRefPolicy
+ {
+ std::string Root;
+ void ValidatePath(const std::filesystem::path& Path) const override
+ {
+ std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string();
+ if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0)
+ {
+ throw std::invalid_argument("path outside root");
+ }
+ }
+ } Policy;
+ // Root is TempRoot, but the file lives in OutsideDir
+ Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string();
+
+ CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.LocalRefPolicy.NoPolicyFailClosed")
+{
+ // When local refs are allowed but no policy is provided, file-path refs should be rejected
+ ScopedTemporaryDirectory TempDir;
+
+ auto FilePath = TempDir.Path() / "data.dat";
+
+ CompressedBuffer Comp =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("data")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ IoHash Hash = Comp.DecodeRawHash();
+ {
+ IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(FilePath, Buf);
+ }
+
+ CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ // kAllowLocalReferences but nullptr policy => fail-closed
+ CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, nullptr), std::invalid_argument);
}
TEST_SUITE_END();
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index 7972777b8..a1a775ba3 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -625,6 +625,8 @@ public:
void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; }
void SetKeepAlive(bool KeepAlive) { m_IsKeepAlive = KeepAlive; }
+ void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); }
+ void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); }
/**
* Initialize the response for sending a payload made up of multiple blobs
@@ -768,10 +770,18 @@ public:
{
ZEN_MEMSCOPE(GetHttpasioTag());
+ std::string_view ContentTypeStr =
+ m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride);
+
m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
- << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n"
+ << "Content-Type: " << ContentTypeStr << "\r\n"
<< "Content-Length: " << ContentLength() << "\r\n"sv;
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv;
+ }
+
if (!m_IsKeepAlive)
{
m_Headers << "Connection: close\r\n"sv;
@@ -898,7 +908,9 @@ private:
bool m_AllowZeroCopyFileSend = true;
State m_State = State::kUninitialized;
HttpContentType m_ContentType = HttpContentType::kBinary;
- uint64_t m_ContentLength = 0;
+ std::string m_ContentTypeOverride;
+ std::string m_ContentRangeHeader;
+ uint64_t m_ContentLength = 0;
eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive
ExtendableStringBuilder<160> m_Headers;
@@ -1275,7 +1287,9 @@ HttpServerConnectionT<SocketType>::HandleRequest()
asio::buffer(ResponseStr->data(), ResponseStr->size()),
asio::bind_executor(
m_Strand,
- [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
+ [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr, PrefixLen = Service->UriPrefixLength()](
+ const asio::error_code& Ec,
+ std::size_t) {
if (Ec)
{
ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
@@ -1287,7 +1301,9 @@ HttpServerConnectionT<SocketType>::HandleRequest()
Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
Ref<WebSocketConnection> WsConnRef(WsConn.Get());
- WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ std::string_view FullUrl = Conn->m_RequestData.Url();
+ std::string_view RelativeUri = FullUrl.substr(std::min(PrefixLen, static_cast<int>(FullUrl.size())));
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri);
WsConn->Start();
}));
@@ -1295,7 +1311,7 @@ HttpServerConnectionT<SocketType>::HandleRequest()
return;
}
}
- // Service doesn't support WebSocket or missing key — fall through to normal handling
+ // Service doesn't support WebSocket or missing key - fall through to normal handling
}
if (!m_RequestData.IsKeepAlive())
@@ -2127,6 +2143,10 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber));
m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
m_Response->SetKeepAlive(m_Request.IsKeepAlive());
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
std::array<IoBuffer, 0> Empty;
m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
@@ -2142,6 +2162,14 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
m_Response.reset(new HttpResponse(ContentType, m_RequestNumber));
m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
m_Response->SetKeepAlive(m_Request.IsKeepAlive());
+ if (!m_ContentTypeOverride.empty())
+ {
+ m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride));
+ }
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
}
diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp
index 918b55dc6..8b07c7905 100644
--- a/src/zenhttp/servers/httpparser.cpp
+++ b/src/zenhttp/servers/httpparser.cpp
@@ -8,6 +8,13 @@
#include <limits>
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+# include <cstring>
+# include <string>
+# include <string_view>
+#endif
+
namespace zen {
using namespace std::literals;
@@ -29,25 +36,25 @@ static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-W
// HttpRequestParser
//
-http_parser_settings HttpRequestParser::s_ParserSettings{
- .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); },
- .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); },
- .on_status =
- [](http_parser* p, const char* Data, size_t ByteCount) {
- ZEN_UNUSED(p, Data, ByteCount);
- return 0;
- },
- .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); },
- .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); },
- .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); },
- .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); },
- .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); },
- .on_chunk_header{},
- .on_chunk_complete{}};
+// clang-format off
+llhttp_settings_t HttpRequestParser::s_ParserSettings = []() {
+ llhttp_settings_t S;
+ llhttp_settings_init(&S);
+ S.on_message_begin = [](llhttp_t* p) { return GetThis(p)->OnMessageBegin(); };
+ S.on_url = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); };
+ S.on_status = [](llhttp_t*, const char*, size_t) { return 0; };
+ S.on_header_field = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); };
+ S.on_header_value = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); };
+ S.on_headers_complete = [](llhttp_t* p) { return GetThis(p)->OnHeadersComplete(); };
+ S.on_body = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); };
+ S.on_message_complete = [](llhttp_t* p) { return GetThis(p)->OnMessageComplete(); };
+ return S;
+}();
+// clang-format on
HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : m_Connection(Connection)
{
- http_parser_init(&m_Parser, HTTP_REQUEST);
+ llhttp_init(&m_Parser, HTTP_REQUEST, &s_ParserSettings);
m_Parser.data = this;
ResetState();
@@ -60,16 +67,17 @@ HttpRequestParser::~HttpRequestParser()
size_t
HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize)
{
- const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize);
-
- http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser));
-
- if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE)
+ llhttp_errno_t Err = llhttp_execute(&m_Parser, InputData, DataSize);
+ if (Err == HPE_OK)
{
- ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno));
- return ~0ull;
+ return DataSize;
}
- return ConsumedBytes;
+ if (Err == HPE_PAUSED_UPGRADE)
+ {
+ return DataSize;
+ }
+ ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", llhttp_errno_name(Err), llhttp_get_error_reason(&m_Parser));
+ return ~0ull;
}
int
@@ -79,7 +87,7 @@ HttpRequestParser::OnUrl(const char* Data, size_t Bytes)
if (RemainingBufferSpace < Bytes)
{
ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
+ return -1;
}
if (m_UrlRange.Length == 0)
@@ -101,7 +109,7 @@ HttpRequestParser::OnHeader(const char* Data, size_t Bytes)
if (RemainingBufferSpace < Bytes)
{
ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
+ return -1;
}
if (m_HeaderEntries.empty())
@@ -212,7 +220,7 @@ HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes)
if (RemainingBufferSpace < Bytes)
{
ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
+ return -1;
}
ZEN_ASSERT_SLOW(!m_HeaderEntries.empty());
@@ -269,9 +277,9 @@ HttpRequestParser::OnHeadersComplete()
}
}
- m_KeepAlive = !!http_should_keep_alive(&m_Parser);
+ m_KeepAlive = !!llhttp_should_keep_alive(&m_Parser);
- switch (m_Parser.method)
+ switch (llhttp_get_method(&m_Parser))
{
case HTTP_GET:
m_RequestVerb = HttpVerb::kGet;
@@ -302,7 +310,7 @@ HttpRequestParser::OnHeadersComplete()
break;
default:
- ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method));
+ ZEN_WARN("invalid HTTP method: '{}'", llhttp_method_name(static_cast<llhttp_method_t>(llhttp_get_method(&m_Parser))));
break;
}
@@ -349,20 +357,11 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes)
{
ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes",
(m_BodyPosition + Bytes) - m_BodyBuffer.Size());
- return 1;
+ return -1;
}
memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes);
m_BodyPosition += Bytes;
- if (http_body_is_final(&m_Parser))
- {
- if (m_BodyPosition != m_BodyBuffer.Size())
- {
- ZEN_WARN("Body size mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
- return 1;
- }
- }
-
return 0;
}
@@ -409,7 +408,7 @@ HttpRequestParser::OnMessageComplete()
catch (const AssertException& AssertEx)
{
ZEN_WARN("Assert caught when processing http request: {}", AssertEx.FullDescription());
- return 1;
+ return -1;
}
catch (const std::system_error& SystemError)
{
@@ -426,19 +425,19 @@ HttpRequestParser::OnMessageComplete()
ZEN_ERROR("failed processing http request: '{}' ({})", SystemError.what(), SystemError.code().value());
}
ResetState();
- return 1;
+ return -1;
}
catch (const std::bad_alloc& BadAlloc)
{
ZEN_WARN("out of memory when processing http request: '{}'", BadAlloc.what());
ResetState();
- return 1;
+ return -1;
}
catch (const std::exception& Ex)
{
ZEN_ERROR("failed processing http request: '{}'", Ex.what());
ResetState();
- return 1;
+ return -1;
}
}
@@ -459,4 +458,331 @@ HttpRequestParser::IsWebSocketUpgrade() const
return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0;
}
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+namespace {
+
+ struct MockCallbacks : HttpRequestParserCallbacks
+ {
+ int HandleRequestCount = 0;
+ int TerminateCount = 0;
+
+ HttpRequestParser* Parser = nullptr;
+
+ HttpVerb LastVerb{};
+ std::string LastUrl;
+ std::string LastQueryString;
+ std::string LastBody;
+ bool LastKeepAlive = false;
+ bool LastIsWebSocketUpgrade = false;
+ std::string LastSecWebSocketKey;
+ std::string LastUpgradeHeader;
+ HttpContentType LastContentType{};
+
+ void HandleRequest() override
+ {
+ ++HandleRequestCount;
+ if (Parser)
+ {
+ LastVerb = Parser->RequestVerb();
+ LastUrl = std::string(Parser->Url());
+ LastQueryString = std::string(Parser->QueryString());
+ LastKeepAlive = Parser->IsKeepAlive();
+ LastIsWebSocketUpgrade = Parser->IsWebSocketUpgrade();
+ LastSecWebSocketKey = std::string(Parser->SecWebSocketKey());
+ LastUpgradeHeader = std::string(Parser->UpgradeHeader());
+ LastContentType = Parser->ContentType();
+
+ IoBuffer Body = Parser->Body();
+ if (Body.Size() > 0)
+ {
+ LastBody.assign(reinterpret_cast<const char*>(Body.Data()), Body.Size());
+ }
+ else
+ {
+ LastBody.clear();
+ }
+ }
+ }
+
+ void TerminateConnection() override { ++TerminateCount; }
+ };
+
+} // anonymous namespace
+
+TEST_SUITE_BEGIN("http.httpparser");
+
+TEST_CASE("httpparser.basic_get")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastVerb, HttpVerb::kGet);
+ CHECK_EQ(Mock.LastUrl, "/path");
+ CHECK(Mock.LastKeepAlive);
+}
+
+TEST_CASE("httpparser.post_with_body")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request =
+ "POST /api HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Content-Length: 13\r\n"
+ "Content-Type: application/json\r\n"
+ "\r\n"
+ "{\"key\":\"val\"}";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastVerb, HttpVerb::kPost);
+ CHECK_EQ(Mock.LastBody, "{\"key\":\"val\"}");
+ CHECK_EQ(Mock.LastContentType, HttpContentType::kJSON);
+}
+
+TEST_CASE("httpparser.pipelined_requests")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request =
+ "GET /first HTTP/1.1\r\nHost: localhost\r\n\r\n"
+ "GET /second HTTP/1.1\r\nHost: localhost\r\n\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 2);
+ CHECK_EQ(Mock.LastUrl, "/second");
+}
+
+TEST_CASE("httpparser.partial_header")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Chunk1 = "GET /path HTTP/1.1\r\nHost: loc";
+ std::string Chunk2 = "alhost\r\n\r\n";
+
+ size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size());
+ CHECK_NE(Consumed1, ~0ull);
+ CHECK_EQ(Consumed1, Chunk1.size());
+ CHECK_EQ(Mock.HandleRequestCount, 0);
+
+ size_t Consumed2 = Parser.ConsumeData(Chunk2.data(), Chunk2.size());
+ CHECK_NE(Consumed2, ~0ull);
+ CHECK_EQ(Consumed2, Chunk2.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastUrl, "/path");
+}
+
+TEST_CASE("httpparser.partial_body")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Headers =
+ "POST /api HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Content-Length: 10\r\n"
+ "\r\n";
+ std::string BodyPart1 = "hello";
+ std::string BodyPart2 = "world";
+
+ std::string Chunk1 = Headers + BodyPart1;
+
+ size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size());
+ CHECK_NE(Consumed1, ~0ull);
+ CHECK_EQ(Consumed1, Chunk1.size());
+ CHECK_EQ(Mock.HandleRequestCount, 0);
+
+ size_t Consumed2 = Parser.ConsumeData(BodyPart2.data(), BodyPart2.size());
+ CHECK_NE(Consumed2, ~0ull);
+ CHECK_EQ(Consumed2, BodyPart2.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastBody, "helloworld");
+}
+
+TEST_CASE("httpparser.invalid_request")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Garbage = "NOT_HTTP garbage data\r\n\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Garbage.data(), Garbage.size());
+ CHECK_EQ(Consumed, ~0ull);
+ CHECK_EQ(Mock.HandleRequestCount, 0);
+}
+
+TEST_CASE("httpparser.body_overflow")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ // llhttp enforces Content-Length strictly: it delivers exactly 2 body bytes,
+ // fires on_message_complete, then tries to parse the remaining "O_LONG_BODY"
+ // as a new HTTP request which fails.
+ std::string Request =
+ "POST /api HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Content-Length: 2\r\n"
+ "\r\n"
+ "TOO_LONG_BODY";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, ~0ull);
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastBody, "TO");
+}
+
+TEST_CASE("httpparser.websocket_upgrade")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request =
+ "GET /ws HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ "Sec-WebSocket-Version: 13\r\n"
+ "\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK(Mock.LastIsWebSocketUpgrade);
+ CHECK_EQ(Mock.LastSecWebSocketKey, "dGhlIHNhbXBsZSBub25jZQ==");
+ CHECK_EQ(Mock.LastUpgradeHeader, "websocket");
+}
+
+TEST_CASE("httpparser.websocket_upgrade_with_trailing_bytes")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string HttpPart =
+ "GET /ws HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ "Sec-WebSocket-Version: 13\r\n"
+ "\r\n";
+
+ // Append fake WebSocket frame bytes after the HTTP message
+ std::string Request = HttpPart;
+ Request.push_back('\x81');
+ Request.push_back('\x05');
+ Request.append("hello");
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_NE(Consumed, ~0ull);
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK(Mock.LastIsWebSocketUpgrade);
+}
+
+TEST_CASE("httpparser.keep_alive_detection")
+{
+ SUBCASE("HTTP/1.1 default keep-alive")
+ {
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n";
+ Parser.ConsumeData(Request.data(), Request.size());
+ CHECK(Mock.LastKeepAlive);
+ }
+
+ SUBCASE("Connection: close disables keep-alive")
+ {
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
+ Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_FALSE(Mock.LastKeepAlive);
+ }
+}
+
+TEST_CASE("httpparser.all_verbs")
+{
+ struct VerbTest
+ {
+ const char* Method;
+ HttpVerb Expected;
+ };
+
+ VerbTest Tests[] = {
+ {"GET", HttpVerb::kGet},
+ {"POST", HttpVerb::kPost},
+ {"PUT", HttpVerb::kPut},
+ {"DELETE", HttpVerb::kDelete},
+ {"HEAD", HttpVerb::kHead},
+ {"COPY", HttpVerb::kCopy},
+ {"OPTIONS", HttpVerb::kOptions},
+ };
+
+ for (const VerbTest& Test : Tests)
+ {
+ CAPTURE(Test.Method);
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = std::string(Test.Method) + " /path HTTP/1.1\r\nHost: localhost\r\n\r\n";
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastVerb, Test.Expected);
+ }
+}
+
+TEST_CASE("httpparser.query_string")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = "GET /path?key=val&other=123 HTTP/1.1\r\nHost: localhost\r\n\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastUrl, "/path");
+ CHECK_EQ(Mock.LastQueryString, "key=val&other=123");
+}
+
+TEST_SUITE_END();
+
+void
+httpparser_forcelink()
+{
+}
+
+#endif // ZEN_WITH_TESTS
+
} // namespace zen
diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h
index 23ad9d8fb..4ff216248 100644
--- a/src/zenhttp/servers/httpparser.h
+++ b/src/zenhttp/servers/httpparser.h
@@ -8,7 +8,7 @@
#include <EASTL/fixed_vector.h>
ZEN_THIRD_PARTY_INCLUDES_START
-#include <http_parser.h>
+#include <llhttp.h>
ZEN_THIRD_PARTY_INCLUDES_END
#include <atomic>
@@ -100,7 +100,7 @@ private:
Oid m_SessionId{};
IoBuffer m_BodyBuffer;
uint64_t m_BodyPosition = 0;
- http_parser m_Parser;
+ llhttp_t m_Parser;
eastl::fixed_vector<char, 512> m_HeaderData;
std::string m_NormalizedUrl;
@@ -114,8 +114,8 @@ private:
int OnBody(const char* Data, size_t Bytes);
int OnMessageComplete();
- static HttpRequestParser* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); }
- static http_parser_settings s_ParserSettings;
+ static HttpRequestParser* GetThis(llhttp_t* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); }
+ static llhttp_settings_t s_ParserSettings;
};
} // namespace zen
diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp
index 31b0315d4..b0fb020e0 100644
--- a/src/zenhttp/servers/httpplugin.cpp
+++ b/src/zenhttp/servers/httpplugin.cpp
@@ -185,13 +185,17 @@ public:
const std::vector<IoBuffer>& ResponseBuffers() const { return m_ResponseBuffers; }
void SuppressPayload() { m_ResponseBuffers.resize(1); }
+ void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); }
+ void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); }
std::string_view GetHeaders();
private:
- uint16_t m_ResponseCode = 0;
- bool m_IsKeepAlive = true;
- HttpContentType m_ContentType = HttpContentType::kBinary;
+ uint16_t m_ResponseCode = 0;
+ bool m_IsKeepAlive = true;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ std::string m_ContentTypeOverride;
+ std::string m_ContentRangeHeader;
uint64_t m_ContentLength = 0;
std::vector<IoBuffer> m_ResponseBuffers;
ExtendableStringBuilder<160> m_Headers;
@@ -246,10 +250,18 @@ HttpPluginResponse::GetHeaders()
if (m_Headers.Size() == 0)
{
+ std::string_view ContentTypeStr =
+ m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride);
+
m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
- << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n"
+ << "Content-Type: " << ContentTypeStr << "\r\n"
<< "Content-Length: " << ContentLength() << "\r\n"sv;
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv;
+ }
+
if (!m_IsKeepAlive)
{
m_Headers << "Connection: close\r\n"sv;
@@ -669,6 +681,10 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode)
ZEN_MEMSCOPE(GetHttppluginTag());
m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary));
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
std::array<IoBuffer, 0> Empty;
m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
@@ -681,6 +697,14 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpConten
ZEN_MEMSCOPE(GetHttppluginTag());
m_Response.reset(new HttpPluginResponse(ContentType));
+ if (!m_ContentTypeOverride.empty())
+ {
+ m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride));
+ }
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
}
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 2cad97725..67b1230a0 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -464,6 +464,8 @@ public:
inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; }
+ void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); }
+ void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); }
private:
eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks;
@@ -473,6 +475,8 @@ private:
uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends
bool m_IsInitialResponse = true;
HttpContentType m_ContentType = HttpContentType::kBinary;
+ std::string m_ContentTypeOverride;
+ std::string m_ContentRangeHeader;
eastl::fixed_vector<IoBuffer, 16> m_DataBuffers;
std::string m_LocationHeader;
@@ -725,7 +729,8 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType];
- std::string_view ContentTypeString = MapContentTypeToString(m_ContentType);
+ std::string_view ContentTypeString =
+ m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride);
ContentTypeHeader->pRawValue = ContentTypeString.data();
ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
@@ -739,6 +744,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size();
}
+ // Content-Range header (for 206 Partial Content single-range responses)
+
+ if (!m_ContentRangeHeader.empty())
+ {
+ PHTTP_KNOWN_HEADER ContentRangeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentRange];
+ ContentRangeHeader->pRawValue = m_ContentRangeHeader.data();
+ ContentRangeHeader->RawValueLength = (USHORT)m_ContentRangeHeader.size();
+ }
+
std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);
HttpResponse.StatusCode = m_ResponseCode;
@@ -1258,7 +1272,7 @@ HttpSysServer::RegisterHttpUrls(int BasePort)
else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED)
{
// Port may be owned by another process's wildcard registration (access denied)
- // or actively in use (sharing violation) — retry on a different port
+ // or actively in use (sharing violation) - retry on a different port
ShouldRetryNextPort = true;
}
else
@@ -2279,6 +2293,11 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
+ if (!m_ContentRangeHeader.empty())
+ {
+ Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
+
if (SuppressBody())
{
Response->SuppressResponseBody();
@@ -2307,6 +2326,15 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
+ if (!m_ContentTypeOverride.empty())
+ {
+ Response->SetContentTypeOverride(std::move(m_ContentTypeOverride));
+ }
+ if (!m_ContentRangeHeader.empty())
+ {
+ Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
+
if (SuppressBody())
{
Response->SuppressResponseBody();
@@ -2595,7 +2623,14 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
&Transaction().Server()));
Ref<WebSocketConnection> WsConnRef(WsConn.Get());
- WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ ExtendableStringBuilder<128> UrlUtf8;
+ WideToUtf8({(wchar_t*)HttpReq->CookedUrl.pAbsPath,
+ gsl::narrow<size_t>(HttpReq->CookedUrl.AbsPathLength / sizeof(wchar_t))},
+ UrlUtf8);
+ int PrefixLen = Service->UriPrefixLength();
+ std::string_view RelativeUri{UrlUtf8.ToView()};
+ RelativeUri.remove_prefix(std::min(PrefixLen, static_cast<int>(RelativeUri.size())));
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri);
WsConn->Start();
return nullptr;
@@ -2603,11 +2638,11 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult);
- // WebSocket upgrade failed — return nullptr since ServerRequest()
+ // WebSocket upgrade failed - return nullptr since ServerRequest()
// was never populated (no InvokeRequestHandler call)
return nullptr;
}
- // Service doesn't support WebSocket or missing key — fall through to normal handling
+ // Service doesn't support WebSocket or missing key - fall through to normal handling
}
}
diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp
index 5ae48f5b3..078c21ea1 100644
--- a/src/zenhttp/servers/wsasio.cpp
+++ b/src/zenhttp/servers/wsasio.cpp
@@ -141,7 +141,7 @@ WsAsioConnectionT<SocketType>::ProcessReceivedData()
}
case WebSocketOpcode::kPong:
- // Unsolicited pong — ignore per RFC 6455
+ // Unsolicited pong - ignore per RFC 6455
break;
case WebSocketOpcode::kClose:
diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp
index af320172d..8520e9f60 100644
--- a/src/zenhttp/servers/wshttpsys.cpp
+++ b/src/zenhttp/servers/wshttpsys.cpp
@@ -70,7 +70,7 @@ WsHttpSysConnection::Shutdown()
return;
}
- // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED
+ // Cancel pending I/O - completions will fire with ERROR_OPERATION_ABORTED
HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
}
@@ -211,7 +211,7 @@ WsHttpSysConnection::ProcessReceivedData()
}
case WebSocketOpcode::kPong:
- // Unsolicited pong — ignore per RFC 6455
+ // Unsolicited pong - ignore per RFC 6455
break;
case WebSocketOpcode::kClose:
@@ -446,7 +446,7 @@ WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason)
m_Handler.OnWebSocketClose(*this, Code, Reason);
- // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED
+ // Cancel pending read I/O - completions drain via ERROR_OPERATION_ABORTED
HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
}
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
index 59c46a418..a58037fec 100644
--- a/src/zenhttp/servers/wstest.cpp
+++ b/src/zenhttp/servers/wstest.cpp
@@ -5,6 +5,7 @@
# include <zencore/scopeguard.h>
# include <zencore/testing.h>
# include <zencore/testutils.h>
+# include <zencore/timer.h>
# include <zenhttp/httpserver.h>
# include <zenhttp/httpwsclient.h>
@@ -59,7 +60,7 @@ TEST_CASE("websocket.framecodec")
std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
- // Server frames are unmasked — TryParseFrame should handle them
+ // Server frames are unmasked - TryParseFrame should handle them
WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
CHECK(Result.IsValid);
@@ -129,7 +130,7 @@ TEST_CASE("websocket.framecodec")
{
std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
- // Pass only 1 byte — not enough for a frame header
+ // Pass only 1 byte - not enough for a frame header
WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1);
CHECK_FALSE(Result.IsValid);
CHECK_EQ(Result.BytesConsumed, 0u);
@@ -335,8 +336,9 @@ namespace {
}
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override
{
+ ZEN_UNUSED(RelativeUri);
m_OpenCount.fetch_add(1);
m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); });
@@ -463,7 +465,7 @@ namespace {
if (!Done.load())
{
- // Timeout — cancel the read
+ // Timeout - cancel the read
asio::error_code Ec;
Sock.cancel(Ec);
}
@@ -476,6 +478,23 @@ namespace {
return Result;
}
+ static void WaitForServerListening(int Port)
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Probe(IoCtx);
+ asio::error_code Ec;
+ Probe.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)), Ec);
+ if (!Ec)
+ {
+ return;
+ }
+ Sleep(10);
+ }
+ }
+
} // anonymous namespace
TEST_CASE("websocket.integration")
@@ -501,8 +520,8 @@ TEST_CASE("websocket.integration")
Server->Close();
});
- // Give server a moment to start accepting
- Sleep(100);
+ // Wait for server to start accepting
+ WaitForServerListening(Port);
SUBCASE("handshake succeeds with 101")
{
@@ -692,7 +711,7 @@ TEST_CASE("websocket.integration")
std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
- // Should NOT get 101 — should fall through to normal request handling
+ // Should NOT get 101 - should fall through to normal request handling
CHECK(Response.find("101") == std::string::npos);
Sock.close();
@@ -813,7 +832,7 @@ TEST_CASE("websocket.client")
Server->Close();
});
- Sleep(100);
+ WaitForServerListening(Port);
SUBCASE("connect, echo, close")
{
@@ -937,7 +956,7 @@ TEST_CASE("websocket.client.unixsocket")
Server->Close();
});
- Sleep(100);
+ WaitForServerListening(Port);
SUBCASE("connect, echo, close over unix socket")
{
diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua
index 7b050ae35..67a01403d 100644
--- a/src/zenhttp/xmake.lua
+++ b/src/zenhttp/xmake.lua
@@ -9,7 +9,7 @@ target('zenhttp')
add_files("servers/wshttpsys.cpp", {unity_ignored=true})
add_includedirs("include", {public=true})
add_deps("zencore", "zentelemetry", "transport-sdk", "asio")
- add_packages("http_parser", "json11", "libcurl")
+ add_packages("llhttp", "json11", "libcurl")
add_options("httpsys")
if is_plat("linux", "macosx") then
diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp
index 3ac8eea8d..e15aa4d30 100644
--- a/src/zenhttp/zenhttp.cpp
+++ b/src/zenhttp/zenhttp.cpp
@@ -4,6 +4,7 @@
#if ZEN_WITH_TESTS
+# include <zenhttp/asynchttpclient.h>
# include <zenhttp/httpclient.h>
# include <zenhttp/httpserver.h>
# include <zenhttp/packageformat.h>
@@ -16,7 +17,9 @@ zenhttp_forcelinktests()
{
http_forcelink();
httpclient_forcelink();
+ httpparser_forcelink();
httpclient_test_forcelink();
+ asynchttpclient_test_forcelink();
forcelink_packageformat();
passwordsecurity_forcelink();
websocket_forcelink();
diff --git a/src/zennomad/include/zennomad/nomadclient.h b/src/zennomad/include/zennomad/nomadclient.h
index 0a3411ace..cebf217e1 100644
--- a/src/zennomad/include/zennomad/nomadclient.h
+++ b/src/zennomad/include/zennomad/nomadclient.h
@@ -52,7 +52,11 @@ public:
/** Build the Nomad job registration JSON for the given job ID and orchestrator endpoint.
* The JSON structure varies based on the configured driver and distribution mode. */
- std::string BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const;
+ std::string BuildJobJson(const std::string& JobId,
+ const std::string& OrchestratorEndpoint,
+ const std::string& CoordinatorSession = {},
+ bool CleanStart = false,
+ const std::string& TraceHost = {}) const;
/** Submit a job via PUT /v1/jobs. On success, populates OutJob with the job info. */
bool SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob);
diff --git a/src/zennomad/include/zennomad/nomadprovisioner.h b/src/zennomad/include/zennomad/nomadprovisioner.h
index 750693b3f..a8368e3dc 100644
--- a/src/zennomad/include/zennomad/nomadprovisioner.h
+++ b/src/zennomad/include/zennomad/nomadprovisioner.h
@@ -47,7 +47,11 @@ public:
/** Construct a provisioner.
* @param Config Nomad connection and job configuration.
* @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. */
- NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint);
+ NomadProvisioner(const NomadConfig& Config,
+ std::string_view OrchestratorEndpoint,
+ std::string_view CoordinatorSession = {},
+ bool CleanStart = false,
+ std::string_view TraceHost = {});
/** Signals the management thread to exit and stops all tracked jobs. */
~NomadProvisioner();
@@ -83,6 +87,9 @@ private:
NomadConfig m_Config;
std::string m_OrchestratorEndpoint;
+ std::string m_CoordinatorSession;
+ bool m_CleanStart = false;
+ std::string m_TraceHost;
std::unique_ptr<NomadClient> m_Client;
diff --git a/src/zennomad/nomadclient.cpp b/src/zennomad/nomadclient.cpp
index 9edcde125..4bb09a930 100644
--- a/src/zennomad/nomadclient.cpp
+++ b/src/zennomad/nomadclient.cpp
@@ -58,7 +58,11 @@ NomadClient::Initialize()
}
std::string
-NomadClient::BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const
+NomadClient::BuildJobJson(const std::string& JobId,
+ const std::string& OrchestratorEndpoint,
+ const std::string& CoordinatorSession,
+ bool CleanStart,
+ const std::string& TraceHost) const
{
ZEN_TRACE_CPU("NomadClient::BuildJobJson");
@@ -94,6 +98,22 @@ NomadClient::BuildJobJson(const std::string& JobId, const std::string& Orchestra
IdArg << "--instance-id=nomad-" << JobId;
Args.push_back(std::string(IdArg.ToView()));
}
+ if (!CoordinatorSession.empty())
+ {
+ ExtendableStringBuilder<128> SessionArg;
+ SessionArg << "--coordinator-session=" << CoordinatorSession;
+ Args.push_back(std::string(SessionArg.ToView()));
+ }
+ if (CleanStart)
+ {
+ Args.push_back("--clean");
+ }
+ if (!TraceHost.empty())
+ {
+ ExtendableStringBuilder<128> TraceArg;
+ TraceArg << "--tracehost=" << TraceHost;
+ Args.push_back(std::string(TraceArg.ToView()));
+ }
TaskConfig["args"] = Args;
}
else
@@ -115,6 +135,22 @@ NomadClient::BuildJobJson(const std::string& JobId, const std::string& Orchestra
IdArg << "--instance-id=nomad-" << JobId;
Args.push_back(std::string(IdArg.ToView()));
}
+ if (!CoordinatorSession.empty())
+ {
+ ExtendableStringBuilder<128> SessionArg;
+ SessionArg << "--coordinator-session=" << CoordinatorSession;
+ Args.push_back(std::string(SessionArg.ToView()));
+ }
+ if (CleanStart)
+ {
+ Args.push_back("--clean");
+ }
+ if (!TraceHost.empty())
+ {
+ ExtendableStringBuilder<128> TraceArg;
+ TraceArg << "--tracehost=" << TraceHost;
+ Args.push_back(std::string(TraceArg.ToView()));
+ }
TaskConfig["args"] = Args;
}
diff --git a/src/zennomad/nomadprocess.cpp b/src/zennomad/nomadprocess.cpp
index 1ae968fb7..deecdef05 100644
--- a/src/zennomad/nomadprocess.cpp
+++ b/src/zennomad/nomadprocess.cpp
@@ -37,7 +37,7 @@ struct NomadProcess::Impl
}
CreateProcOptions Options;
- Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup;
+ Options.Flags |= CreateProcOptions::Flag_NewProcessGroup;
CreateProcResult Result = CreateProc("nomad" ZEN_EXE_SUFFIX_LITERAL, "nomad" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options);
diff --git a/src/zennomad/nomadprovisioner.cpp b/src/zennomad/nomadprovisioner.cpp
index 3fe9c0ac3..e07ce155e 100644
--- a/src/zennomad/nomadprovisioner.cpp
+++ b/src/zennomad/nomadprovisioner.cpp
@@ -14,9 +14,16 @@
namespace zen::nomad {
-NomadProvisioner::NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint)
+NomadProvisioner::NomadProvisioner(const NomadConfig& Config,
+ std::string_view OrchestratorEndpoint,
+ std::string_view CoordinatorSession,
+ bool CleanStart,
+ std::string_view TraceHost)
: m_Config(Config)
, m_OrchestratorEndpoint(OrchestratorEndpoint)
+, m_CoordinatorSession(CoordinatorSession)
+, m_CleanStart(CleanStart)
+, m_TraceHost(TraceHost)
, m_ProcessId(static_cast<uint32_t>(zen::GetCurrentProcessId()))
, m_Log(zen::logging::Get("nomad.provisioner"))
{
@@ -154,7 +161,7 @@ NomadProvisioner::SubmitNewJobs()
ZEN_DEBUG("submitting job '{}' (estimated: {}, target: {})", JobId, m_EstimatedCoreCount.load(), m_TargetCoreCount.load());
- const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint);
+ const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint, m_CoordinatorSession, m_CleanStart, m_TraceHost);
NomadJobInfo JobInfo;
JobInfo.Id = JobId;
diff --git a/src/zenremotestore/builds/buildstoragecache.cpp b/src/zenremotestore/builds/buildstoragecache.cpp
index 0e0b14dca..40ea757eb 100644
--- a/src/zenremotestore/builds/buildstoragecache.cpp
+++ b/src/zenremotestore/builds/buildstoragecache.cpp
@@ -96,7 +96,8 @@ public:
ZEN_ASSERT(!IsFlushed);
ZEN_ASSERT(ContentType == ZenContentType::kCompressedBinary);
- // Move all segments in Payload to be file handle based so if Payload is materialized it does not affect buffers in queue
+ // Move all segments in Payload to be file handle based unless they are very small so if Payload is materialized it does not affect
+ // buffers in queue
std::vector<SharedBuffer> FileBasedSegments;
std::span<const SharedBuffer> Segments = Payload.GetSegments();
FileBasedSegments.reserve(Segments.size());
@@ -104,42 +105,56 @@ public:
tsl::robin_map<void*, std::filesystem::path> HandleToPath;
for (const SharedBuffer& Segment : Segments)
{
- std::filesystem::path FilePath;
- IoBufferFileReference Ref;
- if (Segment.AsIoBuffer().GetFileReference(Ref))
+ const uint64_t SegmentSize = Segment.GetSize();
+ if (SegmentSize < 16u * 1024u)
{
- if (auto It = HandleToPath.find(Ref.FileHandle); It != HandleToPath.end())
- {
- FilePath = It->second;
- }
- else
+ FileBasedSegments.push_back(Segment);
+ }
+ else
+ {
+ std::filesystem::path FilePath;
+ IoBufferFileReference Ref;
+ if (Segment.AsIoBuffer().GetFileReference(Ref))
{
- std::error_code Ec;
- std::filesystem::path Path = PathFromHandle(Ref.FileHandle, Ec);
- if (!Ec && !Path.empty())
+ if (auto It = HandleToPath.find(Ref.FileHandle); It != HandleToPath.end())
{
- HandleToPath.insert_or_assign(Ref.FileHandle, Path);
- FilePath = std::move(Path);
+ FilePath = It->second;
+ }
+ else
+ {
+ std::error_code Ec;
+ std::filesystem::path Path = PathFromHandle(Ref.FileHandle, Ec);
+ if (!Ec && !Path.empty())
+ {
+ HandleToPath.insert_or_assign(Ref.FileHandle, Path);
+ FilePath = std::move(Path);
+ }
+ else
+ {
+ ZEN_WARN("Failed getting path for chunk to upload to cache. Skipping upload.");
+ return;
+ }
}
}
- }
- if (!FilePath.empty())
- {
- IoBuffer BufferFromFile = IoBufferBuilder::MakeFromFile(FilePath, Ref.FileChunkOffset, Ref.FileChunkSize);
- if (BufferFromFile)
+ if (!FilePath.empty())
{
- FileBasedSegments.push_back(SharedBuffer(std::move(BufferFromFile)));
+ IoBuffer BufferFromFile = IoBufferBuilder::MakeFromFile(FilePath, Ref.FileChunkOffset, Ref.FileChunkSize);
+ if (BufferFromFile)
+ {
+ FileBasedSegments.push_back(SharedBuffer(std::move(BufferFromFile)));
+ }
+ else
+ {
+ ZEN_WARN("Failed opening file '{}' to upload to cache. Skipping upload.", FilePath);
+ return;
+ }
}
else
{
FileBasedSegments.push_back(Segment);
}
}
- else
- {
- FileBasedSegments.push_back(Segment);
- }
}
}
diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp
index a04063c4c..6d93d8de6 100644
--- a/src/zenremotestore/builds/buildstorageoperations.cpp
+++ b/src/zenremotestore/builds/buildstorageoperations.cpp
@@ -11,8 +11,7 @@
#include <zenremotestore/chunking/chunkblock.h>
#include <zenremotestore/chunking/chunkingcache.h>
#include <zenremotestore/chunking/chunkingcontroller.h>
-#include <zenremotestore/filesystemutils.h>
-#include <zenremotestore/operationlogoutput.h>
+#include <zenutil/progress.h>
#include <zencore/basicfile.h>
#include <zencore/compactbinary.h>
@@ -26,6 +25,7 @@
#include <zencore/string.h>
#include <zencore/timer.h>
#include <zencore/trace.h>
+#include <zenutil/filesystemutils.h>
#include <zenutil/wildcard.h>
#include <numeric>
@@ -79,7 +79,8 @@ namespace {
return CacheFolderPath / RawHash.ToHexString();
}
- bool CleanDirectory(OperationLogOutput& OperationLogOutput,
+ bool CleanDirectory(LoggerRef InLog,
+ ProgressBase& Progress,
WorkerThreadPool& IOWorkerPool,
std::atomic<bool>& AbortFlag,
std::atomic<bool>& PauseFlag,
@@ -88,10 +89,10 @@ namespace {
std::span<const std::string> ExcludeDirectories)
{
ZEN_TRACE_CPU("CleanDirectory");
+ ZEN_SCOPED_LOG(InLog);
Stopwatch Timer;
- std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(OperationLogOutput.CreateProgressBar("Clean Folder"));
- OperationLogOutput::ProgressBar& Progress(*ProgressBarPtr);
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = Progress.CreateProgressBar("Clean Folder");
CleanDirectoryResult Result = CleanDirectory(
IOWorkerPool,
@@ -100,16 +101,16 @@ namespace {
Path,
ExcludeDirectories,
[&](const std::string_view Details, uint64_t TotalCount, uint64_t RemainingCount, bool IsPaused, bool IsAborted) {
- Progress.UpdateState({.Task = "Cleaning folder ",
- .Details = std::string(Details),
- .TotalCount = TotalCount,
- .RemainingCount = RemainingCount,
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
- false);
+ ProgressBar->UpdateState({.Task = "Cleaning folder ",
+ .Details = std::string(Details),
+ .TotalCount = TotalCount,
+ .RemainingCount = RemainingCount,
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ false);
},
- OperationLogOutput.GetProgressUpdateDelayMS());
+ Progress.GetProgressUpdateDelayMS());
- Progress.Finish();
+ ProgressBar->Finish();
if (AbortFlag)
{
@@ -128,17 +129,16 @@ namespace {
Result.FailedRemovePaths[FailedPathIndex].second.value(),
Result.FailedRemovePaths[FailedPathIndex].second.message());
}
- ZEN_OPERATION_LOG_WARN(OperationLogOutput, "Clean failed to remove files from '{}': {}", Path, SB.ToView());
+ ZEN_WARN("Clean failed to remove files from '{}': {}", Path, SB.ToView());
}
if (ElapsedTimeMs >= 200 && !IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(OperationLogOutput,
- "Wiped folder '{}' {} ({}) in {}",
- Path,
- Result.FoundCount,
- NiceBytes(Result.DeletedByteCount),
- NiceTimeSpanMs(ElapsedTimeMs));
+ ZEN_INFO("Wiped folder '{}' {} ({}) in {}",
+ Path,
+ Result.FoundCount,
+ NiceBytes(Result.DeletedByteCount),
+ NiceTimeSpanMs(ElapsedTimeMs));
}
return Result.FailedRemovePaths.empty();
@@ -150,11 +150,9 @@ namespace {
}
bool IsChunkCompressable(const tsl::robin_set<uint32_t>& NonCompressableExtensionHashes,
- const ChunkedFolderContent& Content,
const ChunkedContentLookup& Lookup,
uint32_t ChunkIndex)
{
- ZEN_UNUSED(Content);
const uint32_t ChunkLocationCount = Lookup.ChunkSequenceLocationCounts[ChunkIndex];
if (ChunkLocationCount == 0)
{
@@ -180,6 +178,54 @@ namespace {
return SB.ToString();
}
+ uint32_t SetNativeFileAttributes(const std::filesystem::path FilePath, SourcePlatform SourcePlatform, uint32_t Attributes)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ if (SourcePlatform == SourcePlatform::Windows)
+ {
+ SetFileAttributesToPath(FilePath, Attributes);
+ return Attributes;
+ }
+ else
+ {
+ uint32_t CurrentAttributes = GetFileAttributesFromPath(FilePath);
+ uint32_t NewAttributes = zen::MakeFileAttributeReadOnly(CurrentAttributes, zen::IsFileModeReadOnly(Attributes));
+ if (CurrentAttributes != NewAttributes)
+ {
+ SetFileAttributesToPath(FilePath, NewAttributes);
+ }
+ return NewAttributes;
+ }
+#endif // ZEN_PLATFORM_WINDOWS
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ if (SourcePlatform != SourcePlatform::Windows)
+ {
+ zen::SetFileMode(FilePath, Attributes);
+ return Attributes;
+ }
+ else
+ {
+ uint32_t CurrentMode = zen::GetFileMode(FilePath);
+ uint32_t NewMode = zen::MakeFileModeReadOnly(CurrentMode, zen::IsFileAttributeReadOnly(Attributes));
+ if (CurrentMode != NewMode)
+ {
+ zen::SetFileMode(FilePath, NewMode);
+ }
+ return NewMode;
+ }
+#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ };
+
+ uint32_t GetNativeFileAttributes(const std::filesystem::path FilePath)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ return GetFileAttributesFromPath(FilePath);
+#endif // ZEN_PLATFORM_WINDOWS
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ return GetFileMode(FilePath);
+#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ }
+
void DownloadLargeBlob(BuildStorageBase& Storage,
const std::filesystem::path& DownloadFolder,
const Oid& BuildId,
@@ -219,7 +265,7 @@ namespace {
Workload->TempFile.Write(Chunk.GetView(), Offset);
}
},
- [&Work, Workload, &DownloadedChunkByteCount, OnDownloadComplete = std::move(OnDownloadComplete)]() {
+ [&Work, Workload, OnDownloadComplete = std::move(OnDownloadComplete)]() {
if (!Work.IsAborted())
{
ZEN_TRACE_CPU("Async_DownloadLargeBlob_OnComplete");
@@ -334,8 +380,120 @@ namespace {
return CompositeBuffer{};
}
+ std::filesystem::path TryMoveDownloadedChunk(IoBuffer& BlockBuffer, const std::filesystem::path& Path, bool ForceDiskBased)
+ {
+ uint64_t BlockSize = BlockBuffer.GetSize();
+ IoBufferFileReference FileRef;
+ if (BlockBuffer.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) && (FileRef.FileChunkSize == BlockSize))
+ {
+ ZEN_TRACE_CPU("MoveTempFullBlock");
+ std::error_code Ec;
+ std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, Ec);
+ if (!Ec)
+ {
+ BlockBuffer.SetDeleteOnClose(false);
+ BlockBuffer = {};
+ RenameFile(TempBlobPath, Path, Ec);
+ if (Ec)
+ {
+ // Re-open the temp file again
+ BasicFile OpenTemp(TempBlobPath, BasicFile::Mode::kDelete);
+ BlockBuffer = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BlockSize, true);
+ BlockBuffer.SetDeleteOnClose(true);
+ }
+ else
+ {
+ return Path;
+ }
+ }
+ }
+
+ if (ForceDiskBased)
+ {
+ // Could not be moved and rather large, lets store it on disk
+ ZEN_TRACE_CPU("WriteTempFullBlock");
+ TemporaryFile::SafeWriteFile(Path, BlockBuffer);
+ BlockBuffer = {};
+ return Path;
+ }
+
+ return {};
+ }
+
} // namespace
+class ReadFileCache
+{
+public:
+ // A buffered file reader that provides CompositeBuffer where the buffers are owned and the memory never overwritten
+ ReadFileCache(std::atomic<uint64_t>& OpenReadCount,
+ std::atomic<uint64_t>& CurrentOpenFileCount,
+ std::atomic<uint64_t>& ReadCount,
+ std::atomic<uint64_t>& ReadByteCount,
+ const std::filesystem::path& Path,
+ const ChunkedFolderContent& LocalContent,
+ const ChunkedContentLookup& LocalLookup,
+ size_t MaxOpenFileCount)
+ : m_Path(Path)
+ , m_LocalContent(LocalContent)
+ , m_LocalLookup(LocalLookup)
+ , m_OpenReadCount(OpenReadCount)
+ , m_CurrentOpenFileCount(CurrentOpenFileCount)
+ , m_ReadCount(ReadCount)
+ , m_ReadByteCount(ReadByteCount)
+ {
+ m_OpenFiles.reserve(MaxOpenFileCount);
+ }
+ ~ReadFileCache() { m_OpenFiles.clear(); }
+
+ CompositeBuffer GetRange(uint32_t SequenceIndex, uint64_t Offset, uint64_t Size)
+ {
+ ZEN_TRACE_CPU("ReadFileCache::GetRange");
+
+ auto CacheIt =
+ std::find_if(m_OpenFiles.begin(), m_OpenFiles.end(), [SequenceIndex](const auto& Lhs) { return Lhs.first == SequenceIndex; });
+ if (CacheIt != m_OpenFiles.end())
+ {
+ if (CacheIt != m_OpenFiles.begin())
+ {
+ auto CachedFile(std::move(CacheIt->second));
+ m_OpenFiles.erase(CacheIt);
+ m_OpenFiles.insert(m_OpenFiles.begin(), std::make_pair(SequenceIndex, std::move(CachedFile)));
+ }
+ CompositeBuffer Result = m_OpenFiles.front().second->GetRange(Offset, Size);
+ return Result;
+ }
+ const uint32_t LocalPathIndex = m_LocalLookup.SequenceIndexFirstPathIndex[SequenceIndex];
+ const std::filesystem::path LocalFilePath = (m_Path / m_LocalContent.Paths[LocalPathIndex]).make_preferred();
+ if (Size == m_LocalContent.RawSizes[LocalPathIndex])
+ {
+ IoBuffer Result = IoBufferBuilder::MakeFromFile(LocalFilePath);
+ return CompositeBuffer(SharedBuffer(Result));
+ }
+ if (m_OpenFiles.size() == m_OpenFiles.capacity())
+ {
+ m_OpenFiles.pop_back();
+ }
+ m_OpenFiles.insert(
+ m_OpenFiles.begin(),
+ std::make_pair(
+ SequenceIndex,
+ std::make_unique<BufferedOpenFile>(LocalFilePath, m_OpenReadCount, m_CurrentOpenFileCount, m_ReadCount, m_ReadByteCount)));
+ CompositeBuffer Result = m_OpenFiles.front().second->GetRange(Offset, Size);
+ return Result;
+ }
+
+private:
+ const std::filesystem::path m_Path;
+ const ChunkedFolderContent& m_LocalContent;
+ const ChunkedContentLookup& m_LocalLookup;
+ std::vector<std::pair<uint32_t, std::unique_ptr<BufferedOpenFile>>> m_OpenFiles;
+ std::atomic<uint64_t>& m_OpenReadCount;
+ std::atomic<uint64_t>& m_CurrentOpenFileCount;
+ std::atomic<uint64_t>& m_ReadCount;
+ std::atomic<uint64_t>& m_ReadByteCount;
+};
+
bool
IsSingleFileChunk(const ChunkedFolderContent& RemoteContent,
const std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> Locations)
@@ -498,7 +656,8 @@ ZenTempFolderPath(const std::filesystem::path& ZenFolderPath)
////////////////////// BuildsOperationUpdateFolder
-BuildsOperationUpdateFolder::BuildsOperationUpdateFolder(OperationLogOutput& OperationLogOutput,
+BuildsOperationUpdateFolder::BuildsOperationUpdateFolder(LoggerRef Log,
+ ProgressBase& Progress,
StorageInstance& Storage,
std::atomic<bool>& AbortFlag,
std::atomic<bool>& PauseFlag,
@@ -513,7 +672,8 @@ BuildsOperationUpdateFolder::BuildsOperationUpdateFolder(OperationLogOutput&
const std::vector<ChunkBlockDescription>& BlockDescriptions,
const std::vector<IoHash>& LooseChunkHashes,
const Options& Options)
-: m_LogOutput(OperationLogOutput)
+: m_Log(Log)
+, m_Progress(Progress)
, m_Storage(Storage)
, m_AbortFlag(AbortFlag)
, m_PauseFlag(PauseFlag)
@@ -551,65 +711,30 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
};
auto EndProgress =
- MakeGuard([&]() { m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::StepCount, (uint32_t)TaskSteps::StepCount); });
-
- ZEN_ASSERT((!m_Options.PrimeCacheOnly) ||
- (m_Options.PrimeCacheOnly && (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Off)));
+ MakeGuard([&]() { m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::StepCount, (uint32_t)TaskSteps::StepCount); });
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::ScanExistingData, (uint32_t)TaskSteps::StepCount);
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::ScanExistingData, (uint32_t)TaskSteps::StepCount);
CreateDirectories(m_CacheFolderPath);
CreateDirectories(m_TempDownloadFolderPath);
CreateDirectories(m_TempBlockFolderPath);
- Stopwatch CacheMappingTimer;
-
std::vector<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters(m_RemoteContent.ChunkedContent.SequenceRawHashes.size());
std::vector<bool> RemoteChunkIndexNeedsCopyFromLocalFileFlags(m_RemoteContent.ChunkedContent.ChunkHashes.size());
std::vector<std::atomic<bool>> RemoteChunkIndexNeedsCopyFromSourceFlags(m_RemoteContent.ChunkedContent.ChunkHashes.size());
tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> CachedChunkHashesFound;
tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> CachedSequenceHashesFound;
- if (!m_Options.PrimeCacheOnly)
- {
- ScanCacheFolder(CachedChunkHashesFound, CachedSequenceHashesFound);
- }
+ ScanCacheFolder(CachedChunkHashesFound, CachedSequenceHashesFound);
tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> CachedBlocksFound;
- if (!m_Options.PrimeCacheOnly)
- {
- ScanTempBlocksFolder(CachedBlocksFound);
- }
+ ScanTempBlocksFolder(CachedBlocksFound);
tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> SequenceIndexesLeftToFindToRemoteIndex;
-
- if (!m_Options.PrimeCacheOnly && m_Options.EnableTargetFolderScavenging)
- {
- // Pick up all whole files we can use from current local state
- ZEN_TRACE_CPU("GetLocalSequences");
-
- Stopwatch LocalTimer;
-
- std::vector<uint32_t> MissingSequenceIndexes = ScanTargetFolder(CachedChunkHashesFound, CachedSequenceHashesFound);
-
- for (uint32_t RemoteSequenceIndex : MissingSequenceIndexes)
- {
- // We must write the sequence
- const uint32_t ChunkCount = m_RemoteContent.ChunkedContent.ChunkCounts[RemoteSequenceIndex];
- const IoHash& RemoteSequenceRawHash = m_RemoteContent.ChunkedContent.SequenceRawHashes[RemoteSequenceIndex];
- SequenceIndexChunksLeftToWriteCounters[RemoteSequenceIndex] = ChunkCount;
- SequenceIndexesLeftToFindToRemoteIndex.insert({RemoteSequenceRawHash, RemoteSequenceIndex});
- }
- }
- else
- {
- for (uint32_t RemoteSequenceIndex = 0; RemoteSequenceIndex < m_RemoteContent.ChunkedContent.SequenceRawHashes.size();
- RemoteSequenceIndex++)
- {
- const uint32_t ChunkCount = m_RemoteContent.ChunkedContent.ChunkCounts[RemoteSequenceIndex];
- SequenceIndexChunksLeftToWriteCounters[RemoteSequenceIndex] = ChunkCount;
- }
- }
+ InitializeSequenceCounters(SequenceIndexChunksLeftToWriteCounters,
+ SequenceIndexesLeftToFindToRemoteIndex,
+ CachedChunkHashesFound,
+ CachedSequenceHashesFound);
std::vector<ChunkedFolderContent> ScavengedContents;
std::vector<ChunkedContentLookup> ScavengedLookups;
@@ -618,7 +743,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
std::vector<ScavengedSequenceCopyOperation> ScavengedSequenceCopyOperations;
uint64_t ScavengedPathsCount = 0;
- if (!m_Options.PrimeCacheOnly && m_Options.EnableOtherDownloadsScavenging)
+ if (m_Options.EnableOtherDownloadsScavenging)
{
ZEN_TRACE_CPU("GetScavengedSequences");
@@ -627,123 +752,19 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
if (!SequenceIndexesLeftToFindToRemoteIndex.empty())
{
std::vector<ScavengeSource> ScavengeSources = FindScavengeSources();
-
- const size_t ScavengePathCount = ScavengeSources.size();
-
- ScavengedContents.resize(ScavengePathCount);
- ScavengedLookups.resize(ScavengePathCount);
- ScavengedPaths.resize(ScavengePathCount);
-
- std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Scavenging"));
- OperationLogOutput::ProgressBar& ScavengeProgressBar(*ProgressBarPtr);
-
- ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
-
- std::atomic<uint64_t> PathsFound(0);
- std::atomic<uint64_t> ChunksFound(0);
- std::atomic<uint64_t> PathsScavenged(0);
-
- for (size_t ScavengeIndex = 0; ScavengeIndex < ScavengePathCount; ScavengeIndex++)
- {
- Work.ScheduleWork(m_IOWorkerPool,
- [this,
- &ScavengeSources,
- &ScavengedContents,
- &ScavengedPaths,
- &ScavengedLookups,
- &PathsFound,
- &ChunksFound,
- &PathsScavenged,
- ScavengeIndex](std::atomic<bool>&) {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("Async_FindScavengeContent");
-
- const ScavengeSource& Source = ScavengeSources[ScavengeIndex];
- ChunkedFolderContent& ScavengedLocalContent = ScavengedContents[ScavengeIndex];
- ChunkedContentLookup& ScavengedLookup = ScavengedLookups[ScavengeIndex];
-
- if (FindScavengeContent(Source, ScavengedLocalContent, ScavengedLookup))
- {
- ScavengedPaths[ScavengeIndex] = Source.Path;
- PathsFound += ScavengedLocalContent.Paths.size();
- ChunksFound += ScavengedLocalContent.ChunkedContent.ChunkHashes.size();
- }
- else
- {
- ScavengedPaths[ScavengeIndex].clear();
- }
- PathsScavenged++;
- }
- });
- }
- {
- ZEN_TRACE_CPU("ScavengeScan_Wait");
-
- Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
- ZEN_UNUSED(PendingWork);
- std::string Details = fmt::format("{}/{} scanned. {} paths and {} chunks found for scavenging",
- PathsScavenged.load(),
- ScavengePathCount,
- PathsFound.load(),
- ChunksFound.load());
- ScavengeProgressBar.UpdateState(
- {.Task = "Scavenging ",
- .Details = Details,
- .TotalCount = ScavengePathCount,
- .RemainingCount = ScavengePathCount - PathsScavenged.load(),
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
- false);
- });
- }
-
- ScavengeProgressBar.Finish();
+ ScanScavengeSources(ScavengeSources, ScavengedContents, ScavengedLookups, ScavengedPaths);
if (m_AbortFlag)
{
return;
}
- for (uint32_t ScavengedContentIndex = 0;
- ScavengedContentIndex < ScavengedContents.size() && (!SequenceIndexesLeftToFindToRemoteIndex.empty());
- ScavengedContentIndex++)
- {
- const std::filesystem::path& ScavengePath = ScavengedPaths[ScavengedContentIndex];
- if (!ScavengePath.empty())
- {
- const ChunkedFolderContent& ScavengedLocalContent = ScavengedContents[ScavengedContentIndex];
- const ChunkedContentLookup& ScavengedLookup = ScavengedLookups[ScavengedContentIndex];
-
- for (uint32_t ScavengedSequenceIndex = 0;
- ScavengedSequenceIndex < ScavengedLocalContent.ChunkedContent.SequenceRawHashes.size();
- ScavengedSequenceIndex++)
- {
- const IoHash& SequenceRawHash = ScavengedLocalContent.ChunkedContent.SequenceRawHashes[ScavengedSequenceIndex];
- if (auto It = SequenceIndexesLeftToFindToRemoteIndex.find(SequenceRawHash);
- It != SequenceIndexesLeftToFindToRemoteIndex.end())
- {
- const uint32_t RemoteSequenceIndex = It->second;
- const uint64_t RawSize =
- m_RemoteContent.RawSizes[m_RemoteLookup.SequenceIndexFirstPathIndex[RemoteSequenceIndex]];
- ZEN_ASSERT(RawSize > 0);
-
- const uint32_t ScavengedPathIndex = ScavengedLookup.SequenceIndexFirstPathIndex[ScavengedSequenceIndex];
- ZEN_ASSERT_SLOW(IsFile((ScavengePath / ScavengedLocalContent.Paths[ScavengedPathIndex]).make_preferred()));
-
- ScavengedSequenceCopyOperations.push_back({.ScavengedContentIndex = ScavengedContentIndex,
- .ScavengedPathIndex = ScavengedPathIndex,
- .RemoteSequenceIndex = RemoteSequenceIndex,
- .RawSize = RawSize});
-
- SequenceIndexesLeftToFindToRemoteIndex.erase(SequenceRawHash);
- SequenceIndexChunksLeftToWriteCounters[RemoteSequenceIndex] = 0;
-
- m_CacheMappingStats.ScavengedPathsMatchingSequencesCount++;
- m_CacheMappingStats.ScavengedPathsMatchingSequencesByteCount += RawSize;
- }
- }
- ScavengedPathsCount++;
- }
- }
+ MatchScavengedSequencesToRemote(ScavengedContents,
+ ScavengedLookups,
+ ScavengedPaths,
+ SequenceIndexesLeftToFindToRemoteIndex,
+ SequenceIndexChunksLeftToWriteCounters,
+ ScavengedSequenceCopyOperations,
+ ScavengedPathsCount);
}
m_CacheMappingStats.ScavengeElapsedWallTimeUs += ScavengeTimer.GetElapsedTimeUs();
}
@@ -762,7 +783,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
tsl::robin_map<IoHash, size_t, IoHash::Hasher> RawHashToCopyChunkDataIndex;
std::vector<CopyChunkData> CopyChunkDatas;
- if (!m_Options.PrimeCacheOnly && m_Options.EnableTargetFolderScavenging)
+ if (m_Options.EnableTargetFolderScavenging)
{
ZEN_TRACE_CPU("GetLocalChunks");
@@ -782,7 +803,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
m_CacheMappingStats.LocalScanElapsedWallTimeUs += LocalTimer.GetElapsedTimeUs();
}
- if (!m_Options.PrimeCacheOnly && m_Options.EnableOtherDownloadsScavenging)
+ if (m_Options.EnableOtherDownloadsScavenging)
{
ZEN_TRACE_CPU("GetScavengeChunks");
@@ -813,54 +834,40 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
if (m_CacheMappingStats.CacheSequenceHashesCount > 0 || m_CacheMappingStats.CacheChunkCount > 0 ||
m_CacheMappingStats.CacheBlockCount > 0)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Download cache: Found {} ({}) chunk sequences, {} ({}) chunks, {} ({}) blocks in {}",
- m_CacheMappingStats.CacheSequenceHashesCount,
- NiceBytes(m_CacheMappingStats.CacheSequenceHashesByteCount),
- m_CacheMappingStats.CacheChunkCount,
- NiceBytes(m_CacheMappingStats.CacheChunkByteCount),
- m_CacheMappingStats.CacheBlockCount,
- NiceBytes(m_CacheMappingStats.CacheBlocksByteCount),
- NiceTimeSpanMs(m_CacheMappingStats.CacheScanElapsedWallTimeUs / 1000));
+ ZEN_INFO("Download cache: Found {} ({}) chunk sequences, {} ({}) chunks, {} ({}) blocks in {}",
+ m_CacheMappingStats.CacheSequenceHashesCount,
+ NiceBytes(m_CacheMappingStats.CacheSequenceHashesByteCount),
+ m_CacheMappingStats.CacheChunkCount,
+ NiceBytes(m_CacheMappingStats.CacheChunkByteCount),
+ m_CacheMappingStats.CacheBlockCount,
+ NiceBytes(m_CacheMappingStats.CacheBlocksByteCount),
+ NiceTimeSpanMs(m_CacheMappingStats.CacheScanElapsedWallTimeUs / 1000));
}
if (m_CacheMappingStats.LocalPathsMatchingSequencesCount > 0 || m_CacheMappingStats.LocalChunkMatchingRemoteCount > 0)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Local state : Found {} ({}) chunk sequences, {} ({}) chunks in {}",
- m_CacheMappingStats.LocalPathsMatchingSequencesCount,
- NiceBytes(m_CacheMappingStats.LocalPathsMatchingSequencesByteCount),
- m_CacheMappingStats.LocalChunkMatchingRemoteCount,
- NiceBytes(m_CacheMappingStats.LocalChunkMatchingRemoteByteCount),
- NiceTimeSpanMs(m_CacheMappingStats.LocalScanElapsedWallTimeUs / 1000));
+ ZEN_INFO("Local state : Found {} ({}) chunk sequences, {} ({}) chunks in {}",
+ m_CacheMappingStats.LocalPathsMatchingSequencesCount,
+ NiceBytes(m_CacheMappingStats.LocalPathsMatchingSequencesByteCount),
+ m_CacheMappingStats.LocalChunkMatchingRemoteCount,
+ NiceBytes(m_CacheMappingStats.LocalChunkMatchingRemoteByteCount),
+ NiceTimeSpanMs(m_CacheMappingStats.LocalScanElapsedWallTimeUs / 1000));
}
if (m_CacheMappingStats.ScavengedPathsMatchingSequencesCount > 0 || m_CacheMappingStats.ScavengedChunkMatchingRemoteCount > 0)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Scavenge of {} paths, found {} ({}) chunk sequences, {} ({}) chunks in {}",
- ScavengedPathsCount,
- m_CacheMappingStats.ScavengedPathsMatchingSequencesCount,
- NiceBytes(m_CacheMappingStats.ScavengedPathsMatchingSequencesByteCount),
- m_CacheMappingStats.ScavengedChunkMatchingRemoteCount,
- NiceBytes(m_CacheMappingStats.ScavengedChunkMatchingRemoteByteCount),
- NiceTimeSpanMs(m_CacheMappingStats.ScavengeElapsedWallTimeUs / 1000));
+ ZEN_INFO("Scavenge of {} paths, found {} ({}) chunk sequences, {} ({}) chunks in {}",
+ ScavengedPathsCount,
+ m_CacheMappingStats.ScavengedPathsMatchingSequencesCount,
+ NiceBytes(m_CacheMappingStats.ScavengedPathsMatchingSequencesByteCount),
+ m_CacheMappingStats.ScavengedChunkMatchingRemoteCount,
+ NiceBytes(m_CacheMappingStats.ScavengedChunkMatchingRemoteByteCount),
+ NiceTimeSpanMs(m_CacheMappingStats.ScavengeElapsedWallTimeUs / 1000));
}
}
- uint64_t BytesToWrite = 0;
-
- for (uint32_t RemoteChunkIndex = 0; RemoteChunkIndex < m_RemoteContent.ChunkedContent.ChunkHashes.size(); RemoteChunkIndex++)
- {
- uint64_t ChunkWriteCount = GetChunkWriteCount(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex);
- if (ChunkWriteCount > 0)
- {
- BytesToWrite += m_RemoteContent.ChunkedContent.ChunkRawSizes[RemoteChunkIndex] * ChunkWriteCount;
- if (!RemoteChunkIndexNeedsCopyFromLocalFileFlags[RemoteChunkIndex])
- {
- RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex] = true;
- }
- }
- }
+ uint64_t BytesToWrite = CalculateBytesToWriteAndFlagNeededChunks(SequenceIndexChunksLeftToWriteCounters,
+ RemoteChunkIndexNeedsCopyFromLocalFileFlags,
+ RemoteChunkIndexNeedsCopyFromSourceFlags);
for (const ScavengedSequenceCopyOperation& ScavengeCopyOp : ScavengedSequenceCopyOperations)
{
@@ -885,7 +892,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
BlobsExistsResult ExistsResult;
{
ChunkBlockAnalyser BlockAnalyser(
- m_LogOutput,
+ Log(),
m_BlockDescriptions,
ChunkBlockAnalyser::Options{.IsQuiet = m_Options.IsQuiet,
.IsVerbose = m_Options.IsVerbose,
@@ -900,183 +907,21 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
std::vector<uint32_t> FetchBlockIndexes;
std::vector<uint32_t> CachedChunkBlockIndexes;
+ ClassifyCachedAndFetchBlocks(NeededBlocks, CachedBlocksFound, TotalPartWriteCount, CachedChunkBlockIndexes, FetchBlockIndexes);
- {
- ZEN_TRACE_CPU("BlockCacheFileExists");
- for (const ChunkBlockAnalyser::NeededBlock& NeededBlock : NeededBlocks)
- {
- if (m_Options.PrimeCacheOnly)
- {
- FetchBlockIndexes.push_back(NeededBlock.BlockIndex);
- }
- else
- {
- const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[NeededBlock.BlockIndex];
- bool UsingCachedBlock = false;
- if (auto It = CachedBlocksFound.find(BlockDescription.BlockHash); It != CachedBlocksFound.end())
- {
- TotalPartWriteCount++;
-
- std::filesystem::path BlockPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString();
- if (IsFile(BlockPath))
- {
- CachedChunkBlockIndexes.push_back(NeededBlock.BlockIndex);
- UsingCachedBlock = true;
- }
- }
- if (!UsingCachedBlock)
- {
- FetchBlockIndexes.push_back(NeededBlock.BlockIndex);
- }
- }
- }
- }
-
- std::vector<uint32_t> NeededLooseChunkIndexes;
-
- {
- NeededLooseChunkIndexes.reserve(m_LooseChunkHashes.size());
- for (uint32_t LooseChunkIndex = 0; LooseChunkIndex < m_LooseChunkHashes.size(); LooseChunkIndex++)
- {
- const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex];
- auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash);
- ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end());
- const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second;
-
- if (RemoteChunkIndexNeedsCopyFromLocalFileFlags[RemoteChunkIndex])
- {
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Skipping chunk {} due to cache reuse",
- m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]);
- }
- continue;
- }
-
- bool NeedsCopy = true;
- if (RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex].compare_exchange_strong(NeedsCopy, false))
- {
- uint64_t WriteCount = GetChunkWriteCount(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex);
- if (WriteCount == 0)
- {
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Skipping chunk {} due to cache reuse",
- m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]);
- }
- }
- else
- {
- NeededLooseChunkIndexes.push_back(LooseChunkIndex);
- }
- }
- }
- }
-
- if (m_Storage.CacheStorage)
- {
- ZEN_TRACE_CPU("BlobCacheExistCheck");
- Stopwatch Timer;
-
- std::vector<IoHash> BlobHashes;
- BlobHashes.reserve(NeededLooseChunkIndexes.size() + FetchBlockIndexes.size());
+ std::vector<uint32_t> NeededLooseChunkIndexes = DetermineNeededLooseChunkIndexes(SequenceIndexChunksLeftToWriteCounters,
+ RemoteChunkIndexNeedsCopyFromLocalFileFlags,
+ RemoteChunkIndexNeedsCopyFromSourceFlags);
- for (const uint32_t LooseChunkIndex : NeededLooseChunkIndexes)
- {
- BlobHashes.push_back(m_LooseChunkHashes[LooseChunkIndex]);
- }
-
- for (uint32_t BlockIndex : FetchBlockIndexes)
- {
- BlobHashes.push_back(m_BlockDescriptions[BlockIndex].BlockHash);
- }
-
- const std::vector<BuildStorageCache::BlobExistsResult> CacheExistsResult =
- m_Storage.CacheStorage->BlobsExists(m_BuildId, BlobHashes);
-
- if (CacheExistsResult.size() == BlobHashes.size())
- {
- ExistsResult.ExistingBlobs.reserve(CacheExistsResult.size());
- for (size_t BlobIndex = 0; BlobIndex < BlobHashes.size(); BlobIndex++)
- {
- if (CacheExistsResult[BlobIndex].HasBody)
- {
- ExistsResult.ExistingBlobs.insert(BlobHashes[BlobIndex]);
- }
- }
- }
- ExistsResult.ElapsedTimeMs = Timer.GetElapsedTimeMs();
- if (!ExistsResult.ExistingBlobs.empty() && !m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Remote cache : Found {} out of {} needed blobs in {}",
- ExistsResult.ExistingBlobs.size(),
- BlobHashes.size(),
- NiceTimeSpanMs(ExistsResult.ElapsedTimeMs));
- }
- }
-
- std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> BlockPartialDownloadModes;
-
- if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Off)
- {
- BlockPartialDownloadModes.resize(m_BlockDescriptions.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::Off);
- }
- else
- {
- ChunkBlockAnalyser::EPartialBlockDownloadMode CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off;
- ChunkBlockAnalyser::EPartialBlockDownloadMode CachePartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off;
-
- switch (m_Options.PartialBlockRequestMode)
- {
- case EPartialBlockRequestMode::Off:
- break;
- case EPartialBlockRequestMode::ZenCacheOnly:
- CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1
- ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed
- : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange;
- CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off;
- break;
- case EPartialBlockRequestMode::Mixed:
- CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1
- ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed
- : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange;
- CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange;
- break;
- case EPartialBlockRequestMode::All:
- CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1
- ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed
- : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange;
- CloudPartialDownloadMode = m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest > 1
- ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange
- : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange;
- break;
- default:
- ZEN_ASSERT(false);
- break;
- }
-
- BlockPartialDownloadModes.reserve(m_BlockDescriptions.size());
- for (uint32_t BlockIndex = 0; BlockIndex < m_BlockDescriptions.size(); BlockIndex++)
- {
- const bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(m_BlockDescriptions[BlockIndex].BlockHash);
- BlockPartialDownloadModes.push_back(BlockExistInCache ? CachePartialDownloadMode : CloudPartialDownloadMode);
- }
- }
+ ExistsResult = QueryBlobCacheExists(NeededLooseChunkIndexes, FetchBlockIndexes);
+ std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> BlockPartialDownloadModes =
+ DeterminePartialDownloadModes(ExistsResult);
ZEN_ASSERT(BlockPartialDownloadModes.size() == m_BlockDescriptions.size());
ChunkBlockAnalyser::BlockResult PartialBlocks =
BlockAnalyser.CalculatePartialBlockDownloads(NeededBlocks, BlockPartialDownloadModes);
- struct LooseChunkHashWorkData
- {
- std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs;
- uint32_t RemoteChunkIndex = (uint32_t)-1;
- };
-
TotalRequestCount += NeededLooseChunkIndexes.size();
TotalPartWriteCount += NeededLooseChunkIndexes.size();
TotalRequestCount += PartialBlocks.BlockRanges.size();
@@ -1084,626 +929,52 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
TotalRequestCount += PartialBlocks.FullBlockIndexes.size();
TotalPartWriteCount += PartialBlocks.FullBlockIndexes.size();
- std::vector<LooseChunkHashWorkData> LooseChunkHashWorks;
- for (uint32_t LooseChunkIndex : NeededLooseChunkIndexes)
- {
- const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex];
- auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash);
- ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end());
- const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second;
-
- std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs =
- GetRemainingChunkTargets(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex);
-
- ZEN_ASSERT(!ChunkTargetPtrs.empty());
- LooseChunkHashWorks.push_back(
- LooseChunkHashWorkData{.ChunkTargetPtrs = ChunkTargetPtrs, .RemoteChunkIndex = RemoteChunkIndex});
- }
+ std::vector<LooseChunkHashWorkData> LooseChunkHashWorks =
+ BuildLooseChunkHashWorks(NeededLooseChunkIndexes, SequenceIndexChunksLeftToWriteCounters);
ZEN_TRACE_CPU("WriteChunks");
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::WriteChunks, (uint32_t)TaskSteps::StepCount);
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::WriteChunks, (uint32_t)TaskSteps::StepCount);
Stopwatch WriteTimer;
FilteredRate FilteredDownloadedBytesPerSecond;
FilteredRate FilteredWrittenBytesPerSecond;
- std::unique_ptr<OperationLogOutput::ProgressBar> WriteProgressBarPtr(
- m_LogOutput.CreateProgressBar(m_Options.PrimeCacheOnly ? "Downloading" : "Writing"));
- OperationLogOutput::ProgressBar& WriteProgressBar(*WriteProgressBarPtr);
- ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Writing");
+ ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
TotalPartWriteCount += CopyChunkDatas.size();
TotalPartWriteCount += ScavengedSequenceCopyOperations.size();
BufferedWriteFileCache WriteCache;
- for (uint32_t ScavengeOpIndex = 0; ScavengeOpIndex < ScavengedSequenceCopyOperations.size(); ScavengeOpIndex++)
- {
- if (m_AbortFlag)
- {
- break;
- }
- if (!m_Options.PrimeCacheOnly)
- {
- Work.ScheduleWork(
- m_IOWorkerPool,
- [this,
- &ScavengedPaths,
- &ScavengedSequenceCopyOperations,
- &ScavengedContents,
- &FilteredWrittenBytesPerSecond,
- ScavengeOpIndex,
- &WritePartsComplete,
- TotalPartWriteCount](std::atomic<bool>&) mutable {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("Async_WriteScavenged");
-
- FilteredWrittenBytesPerSecond.Start();
-
- const ScavengedSequenceCopyOperation& ScavengeOp = ScavengedSequenceCopyOperations[ScavengeOpIndex];
- const ChunkedFolderContent& ScavengedContent = ScavengedContents[ScavengeOp.ScavengedContentIndex];
- const std::filesystem::path& ScavengeRootPath = ScavengedPaths[ScavengeOp.ScavengedContentIndex];
-
- WriteScavengedSequenceToCache(ScavengeRootPath, ScavengedContent, ScavengeOp);
-
- WritePartsComplete++;
- if (WritePartsComplete == TotalPartWriteCount)
- {
- FilteredWrittenBytesPerSecond.Stop();
- }
- }
- });
- }
- }
-
- for (uint32_t LooseChunkHashWorkIndex = 0; LooseChunkHashWorkIndex < LooseChunkHashWorks.size(); LooseChunkHashWorkIndex++)
- {
- if (m_AbortFlag)
- {
- break;
- }
-
- if (m_Options.PrimeCacheOnly)
- {
- const uint32_t RemoteChunkIndex = LooseChunkHashWorks[LooseChunkHashWorkIndex].RemoteChunkIndex;
- if (ExistsResult.ExistingBlobs.contains(m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]))
- {
- m_DownloadStats.RequestsCompleteCount++;
- continue;
- }
- }
-
- Work.ScheduleWork(
- m_IOWorkerPool,
- [this,
- &SequenceIndexChunksLeftToWriteCounters,
- &Work,
- &ExistsResult,
- &WritePartsComplete,
- &LooseChunkHashWorks,
- LooseChunkHashWorkIndex,
- TotalRequestCount,
- TotalPartWriteCount,
- &WriteCache,
- &FilteredDownloadedBytesPerSecond,
- &FilteredWrittenBytesPerSecond](std::atomic<bool>&) mutable {
- ZEN_TRACE_CPU("Async_ReadPreDownloadedChunk");
- if (!m_AbortFlag)
- {
- LooseChunkHashWorkData& LooseChunkHashWork = LooseChunkHashWorks[LooseChunkHashWorkIndex];
- const uint32_t RemoteChunkIndex = LooseChunkHashWorks[LooseChunkHashWorkIndex].RemoteChunkIndex;
- WriteLooseChunk(RemoteChunkIndex,
- ExistsResult,
- SequenceIndexChunksLeftToWriteCounters,
- WritePartsComplete,
- std::move(LooseChunkHashWork.ChunkTargetPtrs),
- WriteCache,
- Work,
- TotalRequestCount,
- TotalPartWriteCount,
- FilteredDownloadedBytesPerSecond,
- FilteredWrittenBytesPerSecond);
- }
- },
- WorkerThreadPool::EMode::EnableBacklog);
- }
-
- std::unique_ptr<CloneQueryInterface> CloneQuery;
- if (m_Options.AllowFileClone)
- {
- CloneQuery = GetCloneQueryInterface(m_CacheFolderPath);
- }
-
- for (size_t CopyDataIndex = 0; CopyDataIndex < CopyChunkDatas.size(); CopyDataIndex++)
- {
- ZEN_ASSERT(!m_Options.PrimeCacheOnly);
- if (m_AbortFlag)
- {
- break;
- }
-
- Work.ScheduleWork(m_IOWorkerPool,
- [this,
- &CloneQuery,
- &SequenceIndexChunksLeftToWriteCounters,
- &WriteCache,
- &Work,
- &FilteredWrittenBytesPerSecond,
- &CopyChunkDatas,
- &ScavengedContents,
- &ScavengedLookups,
- &ScavengedPaths,
- &WritePartsComplete,
- TotalPartWriteCount,
- CopyDataIndex](std::atomic<bool>&) {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("Async_CopyLocal");
-
- FilteredWrittenBytesPerSecond.Start();
- const CopyChunkData& CopyData = CopyChunkDatas[CopyDataIndex];
-
- std::vector<uint32_t> WrittenSequenceIndexes = WriteLocalChunkToCache(CloneQuery.get(),
- CopyData,
- ScavengedContents,
- ScavengedLookups,
- ScavengedPaths,
- WriteCache);
- WritePartsComplete++;
- if (!m_AbortFlag)
- {
- if (WritePartsComplete == TotalPartWriteCount)
- {
- FilteredWrittenBytesPerSecond.Stop();
- }
-
- // Write tracking, updating this must be done without any files open
- std::vector<uint32_t> CompletedChunkSequences;
- for (uint32_t RemoteSequenceIndex : WrittenSequenceIndexes)
- {
- if (CompleteSequenceChunk(RemoteSequenceIndex, SequenceIndexChunksLeftToWriteCounters))
- {
- CompletedChunkSequences.push_back(RemoteSequenceIndex);
- }
- }
- WriteCache.Close(CompletedChunkSequences);
- VerifyAndCompleteChunkSequencesAsync(CompletedChunkSequences, Work);
- }
- }
- });
- }
-
- for (uint32_t BlockIndex : CachedChunkBlockIndexes)
- {
- ZEN_ASSERT(!m_Options.PrimeCacheOnly);
- if (m_AbortFlag)
- {
- break;
- }
-
- Work.ScheduleWork(
- m_IOWorkerPool,
- [this,
- &RemoteChunkIndexNeedsCopyFromSourceFlags,
- &SequenceIndexChunksLeftToWriteCounters,
- &WriteCache,
- &Work,
- &FilteredWrittenBytesPerSecond,
- &WritePartsComplete,
- TotalPartWriteCount,
- BlockIndex](std::atomic<bool>&) mutable {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("Async_WriteCachedBlock");
-
- const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
- FilteredWrittenBytesPerSecond.Start();
-
- std::filesystem::path BlockChunkPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString();
- IoBuffer BlockBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath);
- if (!BlockBuffer)
- {
- throw std::runtime_error(
- fmt::format("Can not read block {} at {}", BlockDescription.BlockHash, BlockChunkPath));
- }
-
- if (!m_AbortFlag)
- {
- if (!WriteChunksBlockToCache(BlockDescription,
- SequenceIndexChunksLeftToWriteCounters,
- Work,
- CompositeBuffer(std::move(BlockBuffer)),
- RemoteChunkIndexNeedsCopyFromSourceFlags,
- WriteCache))
- {
- std::error_code DummyEc;
- RemoveFile(BlockChunkPath, DummyEc);
- throw std::runtime_error(fmt::format("Block {} is malformed", BlockDescription.BlockHash));
- }
+ WriteChunksContext Context{.Work = Work,
+ .WriteCache = WriteCache,
+ .SequenceIndexChunksLeftToWriteCounters = SequenceIndexChunksLeftToWriteCounters,
+ .RemoteChunkIndexNeedsCopyFromSourceFlags = RemoteChunkIndexNeedsCopyFromSourceFlags,
+ .WritePartsComplete = WritePartsComplete,
+ .TotalPartWriteCount = TotalPartWriteCount,
+ .TotalRequestCount = TotalRequestCount,
+ .ExistsResult = ExistsResult,
+ .FilteredDownloadedBytesPerSecond = FilteredDownloadedBytesPerSecond,
+ .FilteredWrittenBytesPerSecond = FilteredWrittenBytesPerSecond};
- std::error_code Ec = TryRemoveFile(BlockChunkPath);
- if (Ec)
- {
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput,
- "Failed removing file '{}', reason: ({}) {}",
- BlockChunkPath,
- Ec.value(),
- Ec.message());
- }
-
- WritePartsComplete++;
-
- if (WritePartsComplete == TotalPartWriteCount)
- {
- FilteredWrittenBytesPerSecond.Stop();
- }
- }
- }
- });
- }
-
- for (size_t BlockRangeIndex = 0; BlockRangeIndex < PartialBlocks.BlockRanges.size();)
- {
- ZEN_ASSERT(!m_Options.PrimeCacheOnly);
- if (m_AbortFlag)
- {
- break;
- }
-
- size_t RangeCount = 1;
- size_t RangesLeft = PartialBlocks.BlockRanges.size() - BlockRangeIndex;
- const ChunkBlockAnalyser::BlockRangeDescriptor& CurrentBlockRange = PartialBlocks.BlockRanges[BlockRangeIndex];
- while (RangeCount < RangesLeft &&
- CurrentBlockRange.BlockIndex == PartialBlocks.BlockRanges[BlockRangeIndex + RangeCount].BlockIndex)
- {
- RangeCount++;
- }
-
- Work.ScheduleWork(
- m_NetworkPool,
- [this,
- &RemoteChunkIndexNeedsCopyFromSourceFlags,
- &SequenceIndexChunksLeftToWriteCounters,
- &ExistsResult,
- &WriteCache,
- &FilteredDownloadedBytesPerSecond,
- TotalRequestCount,
- &WritePartsComplete,
- TotalPartWriteCount,
- &FilteredWrittenBytesPerSecond,
- &Work,
- &PartialBlocks,
- BlockRangeStartIndex = BlockRangeIndex,
- RangeCount = RangeCount](std::atomic<bool>&) {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("Async_GetPartialBlockRanges");
-
- FilteredDownloadedBytesPerSecond.Start();
-
- DownloadPartialBlock(
- PartialBlocks.BlockRanges,
- BlockRangeStartIndex,
- RangeCount,
- ExistsResult,
- [this,
- &RemoteChunkIndexNeedsCopyFromSourceFlags,
- &SequenceIndexChunksLeftToWriteCounters,
- &WritePartsComplete,
- &WriteCache,
- &Work,
- TotalRequestCount,
- TotalPartWriteCount,
- &FilteredDownloadedBytesPerSecond,
- &FilteredWrittenBytesPerSecond,
- &PartialBlocks](IoBuffer&& InMemoryBuffer,
- const std::filesystem::path& OnDiskPath,
- size_t BlockRangeStartIndex,
- std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths) {
- if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount)
- {
- FilteredDownloadedBytesPerSecond.Stop();
- }
-
- if (!m_AbortFlag)
- {
- Work.ScheduleWork(
- m_IOWorkerPool,
- [this,
- &RemoteChunkIndexNeedsCopyFromSourceFlags,
- &SequenceIndexChunksLeftToWriteCounters,
- &WritePartsComplete,
- &WriteCache,
- &Work,
- TotalPartWriteCount,
- &FilteredWrittenBytesPerSecond,
- &PartialBlocks,
- BlockRangeStartIndex,
- BlockChunkPath = std::filesystem::path(OnDiskPath),
- BlockPartialBuffer = std::move(InMemoryBuffer),
- OffsetAndLengths = std::vector<std::pair<uint64_t, uint64_t>>(OffsetAndLengths.begin(),
- OffsetAndLengths.end())](
- std::atomic<bool>&) mutable {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("Async_WritePartialBlock");
-
- const uint32_t BlockIndex = PartialBlocks.BlockRanges[BlockRangeStartIndex].BlockIndex;
-
- const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
-
- if (BlockChunkPath.empty())
- {
- ZEN_ASSERT(BlockPartialBuffer);
- }
- else
- {
- ZEN_ASSERT(!BlockPartialBuffer);
- BlockPartialBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath);
- if (!BlockPartialBuffer)
- {
- throw std::runtime_error(
- fmt::format("Could not open downloaded block {} from {}",
- BlockDescription.BlockHash,
- BlockChunkPath));
- }
- }
-
- FilteredWrittenBytesPerSecond.Start();
-
- size_t RangeCount = OffsetAndLengths.size();
-
- for (size_t PartialRangeIndex = 0; PartialRangeIndex < RangeCount; PartialRangeIndex++)
- {
- const std::pair<uint64_t, uint64_t>& OffsetAndLength =
- OffsetAndLengths[PartialRangeIndex];
- IoBuffer BlockRangeBuffer(BlockPartialBuffer,
- OffsetAndLength.first,
- OffsetAndLength.second);
-
- const ChunkBlockAnalyser::BlockRangeDescriptor& RangeDescriptor =
- PartialBlocks.BlockRanges[BlockRangeStartIndex + PartialRangeIndex];
-
- if (!WritePartialBlockChunksToCache(BlockDescription,
- SequenceIndexChunksLeftToWriteCounters,
- Work,
- CompositeBuffer(std::move(BlockRangeBuffer)),
- RangeDescriptor.ChunkBlockIndexStart,
- RangeDescriptor.ChunkBlockIndexStart +
- RangeDescriptor.ChunkBlockIndexCount - 1,
- RemoteChunkIndexNeedsCopyFromSourceFlags,
- WriteCache))
- {
- std::error_code DummyEc;
- RemoveFile(BlockChunkPath, DummyEc);
- throw std::runtime_error(
- fmt::format("Partial block {} is malformed", BlockDescription.BlockHash));
- }
-
- WritePartsComplete++;
- if (WritePartsComplete == TotalPartWriteCount)
- {
- FilteredWrittenBytesPerSecond.Stop();
- }
- }
- std::error_code Ec = TryRemoveFile(BlockChunkPath);
- if (Ec)
- {
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput,
- "Failed removing file '{}', reason: ({}) {}",
- BlockChunkPath,
- Ec.value(),
- Ec.message());
- }
- }
- },
- OnDiskPath.empty() ? WorkerThreadPool::EMode::DisableBacklog
- : WorkerThreadPool::EMode::EnableBacklog);
- }
- });
- }
- });
- BlockRangeIndex += RangeCount;
- }
-
- for (uint32_t BlockIndex : PartialBlocks.FullBlockIndexes)
- {
- if (m_AbortFlag)
- {
- break;
- }
-
- if (m_Options.PrimeCacheOnly && ExistsResult.ExistingBlobs.contains(m_BlockDescriptions[BlockIndex].BlockHash))
- {
- m_DownloadStats.RequestsCompleteCount++;
- continue;
- }
-
- Work.ScheduleWork(
- m_NetworkPool,
- [this,
- &WritePartsComplete,
- TotalPartWriteCount,
- &FilteredWrittenBytesPerSecond,
- &ExistsResult,
- &Work,
- &WriteCache,
- &RemoteChunkIndexNeedsCopyFromSourceFlags,
- &SequenceIndexChunksLeftToWriteCounters,
- &FilteredDownloadedBytesPerSecond,
- TotalRequestCount,
- BlockIndex](std::atomic<bool>&) {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("Async_GetFullBlock");
-
- const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
-
- FilteredDownloadedBytesPerSecond.Start();
-
- IoBuffer BlockBuffer;
- const bool ExistsInCache =
- m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash);
- if (ExistsInCache)
- {
- BlockBuffer = m_Storage.CacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash);
- }
- if (!BlockBuffer)
- {
- BlockBuffer = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash);
- if (BlockBuffer && m_Storage.CacheStorage && m_Options.PopulateCache)
- {
- m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
- BlockDescription.BlockHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(SharedBuffer(BlockBuffer)));
- }
- }
- if (!BlockBuffer)
- {
- throw std::runtime_error(fmt::format("Block {} is missing", BlockDescription.BlockHash));
- }
- if (!m_AbortFlag)
- {
- uint64_t BlockSize = BlockBuffer.GetSize();
- m_DownloadStats.DownloadedBlockCount++;
- m_DownloadStats.DownloadedBlockByteCount += BlockSize;
- m_DownloadStats.RequestsCompleteCount++;
- if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount)
- {
- FilteredDownloadedBytesPerSecond.Stop();
- }
+ ScheduleScavengedSequenceWrites(Context, ScavengedSequenceCopyOperations, ScavengedContents, ScavengedPaths);
+ ScheduleLooseChunkWrites(Context, LooseChunkHashWorks);
- if (!m_Options.PrimeCacheOnly)
- {
- std::filesystem::path BlockChunkPath;
+ std::unique_ptr<CloneQueryInterface> CloneQuery =
+ m_Options.AllowFileClone ? GetCloneQueryInterface(m_CacheFolderPath) : nullptr;
- // Check if the dowloaded block is file based and we can move it directly without rewriting it
- {
- IoBufferFileReference FileRef;
- if (BlockBuffer.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) &&
- (FileRef.FileChunkSize == BlockSize))
- {
- ZEN_TRACE_CPU("MoveTempFullBlock");
- std::error_code Ec;
- std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, Ec);
- if (!Ec)
- {
- BlockBuffer.SetDeleteOnClose(false);
- BlockBuffer = {};
- BlockChunkPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString();
- RenameFile(TempBlobPath, BlockChunkPath, Ec);
- if (Ec)
- {
- BlockChunkPath = std::filesystem::path{};
-
- // Re-open the temp file again
- BasicFile OpenTemp(TempBlobPath, BasicFile::Mode::kDelete);
- BlockBuffer = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BlockSize, true);
- BlockBuffer.SetDeleteOnClose(true);
- }
- }
- }
- }
-
- if (BlockChunkPath.empty() && (BlockSize > m_Options.MaximumInMemoryPayloadSize))
- {
- ZEN_TRACE_CPU("WriteTempFullBlock");
- // Could not be moved and rather large, lets store it on disk
- BlockChunkPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString();
- TemporaryFile::SafeWriteFile(BlockChunkPath, BlockBuffer);
- BlockBuffer = {};
- }
-
- if (!m_AbortFlag)
- {
- Work.ScheduleWork(
- m_IOWorkerPool,
- [this,
- &Work,
- &RemoteChunkIndexNeedsCopyFromSourceFlags,
- &SequenceIndexChunksLeftToWriteCounters,
- BlockIndex,
- &WriteCache,
- &WritePartsComplete,
- TotalPartWriteCount,
- &FilteredWrittenBytesPerSecond,
- BlockChunkPath,
- BlockBuffer = std::move(BlockBuffer)](std::atomic<bool>&) mutable {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("Async_WriteFullBlock");
-
- const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
-
- if (BlockChunkPath.empty())
- {
- ZEN_ASSERT(BlockBuffer);
- }
- else
- {
- ZEN_ASSERT(!BlockBuffer);
- BlockBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath);
- if (!BlockBuffer)
- {
- throw std::runtime_error(
- fmt::format("Could not open dowloaded block {} from {}",
- BlockDescription.BlockHash,
- BlockChunkPath));
- }
- }
-
- FilteredWrittenBytesPerSecond.Start();
- if (!WriteChunksBlockToCache(BlockDescription,
- SequenceIndexChunksLeftToWriteCounters,
- Work,
- CompositeBuffer(std::move(BlockBuffer)),
- RemoteChunkIndexNeedsCopyFromSourceFlags,
- WriteCache))
- {
- std::error_code DummyEc;
- RemoveFile(BlockChunkPath, DummyEc);
- throw std::runtime_error(
- fmt::format("Block {} is malformed", BlockDescription.BlockHash));
- }
-
- if (!BlockChunkPath.empty())
- {
- std::error_code Ec = TryRemoveFile(BlockChunkPath);
- if (Ec)
- {
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput,
- "Failed removing file '{}', reason: ({}) {}",
- BlockChunkPath,
- Ec.value(),
- Ec.message());
- }
- }
-
- WritePartsComplete++;
-
- if (WritePartsComplete == TotalPartWriteCount)
- {
- FilteredWrittenBytesPerSecond.Stop();
- }
- }
- },
- BlockChunkPath.empty() ? WorkerThreadPool::EMode::DisableBacklog
- : WorkerThreadPool::EMode::EnableBacklog);
- }
- }
- }
- }
- });
- }
+ ScheduleLocalChunkCopies(Context, CopyChunkDatas, CloneQuery.get(), ScavengedContents, ScavengedLookups, ScavengedPaths);
+ ScheduleCachedBlockWrites(Context, CachedChunkBlockIndexes);
+ SchedulePartialBlockDownloads(Context, PartialBlocks);
+ ScheduleFullBlockDownloads(Context, PartialBlocks.FullBlockIndexes);
{
ZEN_TRACE_CPU("WriteChunks_Wait");
- Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
+ Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
ZEN_UNUSED(PendingWork);
uint64_t DownloadedBytes = m_DownloadStats.DownloadedChunkByteCount.load() +
m_DownloadStats.DownloadedBlockByteCount.load() +
@@ -1719,12 +990,11 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
{
CloneDetails = fmt::format(" ({} cloned)", NiceBytes(m_DiskStats.CloneByteCount.load()));
}
- std::string WriteDetails = m_Options.PrimeCacheOnly ? ""
- : fmt::format(" {}/{} ({}B/s) written{}",
- NiceBytes(m_WrittenChunkByteCount.load()),
- NiceBytes(BytesToWrite),
- NiceNum(FilteredWrittenBytesPerSecond.GetCurrent()),
- CloneDetails);
+ std::string WriteDetails = fmt::format(" {}/{} ({}B/s) written{}",
+ NiceBytes(m_WrittenChunkByteCount.load()),
+ NiceBytes(BytesToWrite),
+ NiceNum(FilteredWrittenBytesPerSecond.GetCurrent()),
+ CloneDetails);
std::string Details = fmt::format("{}/{} ({}{}) downloaded.{}",
m_DownloadStats.RequestsCompleteCount.load(),
@@ -1734,11 +1004,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
WriteDetails);
std::string Task;
- if (m_Options.PrimeCacheOnly)
- {
- Task = "Downloading ";
- }
- else if ((m_WrittenChunkByteCount < BytesToWrite) || (BytesToValidate == 0))
+ if ((m_WrittenChunkByteCount < BytesToWrite) || (BytesToValidate == 0))
{
Task = "Writing chunks ";
}
@@ -1747,15 +1013,13 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
Task = "Verifying chunks ";
}
- WriteProgressBar.UpdateState(
- {.Task = Task,
- .Details = Details,
- .TotalCount = m_Options.PrimeCacheOnly ? TotalRequestCount : (BytesToWrite + BytesToValidate),
- .RemainingCount = m_Options.PrimeCacheOnly ? (TotalRequestCount - m_DownloadStats.RequestsCompleteCount.load())
- : ((BytesToWrite + BytesToValidate) -
- (m_WrittenChunkByteCount.load() + m_ValidatedChunkByteCount.load())),
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
- false);
+ ProgressBar->UpdateState({.Task = Task,
+ .Details = Details,
+ .TotalCount = (BytesToWrite + BytesToValidate),
+ .RemainingCount = ((BytesToWrite + BytesToValidate) -
+ (m_WrittenChunkByteCount.load() + m_ValidatedChunkByteCount.load())),
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ false);
});
}
@@ -1764,40 +1028,13 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
FilteredWrittenBytesPerSecond.Stop();
FilteredDownloadedBytesPerSecond.Stop();
- WriteProgressBar.Finish();
+ ProgressBar->Finish();
if (m_AbortFlag)
{
return;
}
- if (!m_Options.PrimeCacheOnly)
- {
- uint32_t RawSequencesMissingWriteCount = 0;
- for (uint32_t SequenceIndex = 0; SequenceIndex < SequenceIndexChunksLeftToWriteCounters.size(); SequenceIndex++)
- {
- const auto& SequenceIndexChunksLeftToWriteCounter = SequenceIndexChunksLeftToWriteCounters[SequenceIndex];
- if (SequenceIndexChunksLeftToWriteCounter.load() != 0)
- {
- RawSequencesMissingWriteCount++;
- const uint32_t PathIndex = m_RemoteLookup.SequenceIndexFirstPathIndex[SequenceIndex];
- const std::filesystem::path& IncompletePath = m_RemoteContent.Paths[PathIndex];
- ZEN_ASSERT(!IncompletePath.empty());
- const uint32_t ExpectedSequenceCount = m_RemoteContent.ChunkedContent.ChunkCounts[SequenceIndex];
- if (!m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "{}: Max count {}, Current count {}",
- IncompletePath,
- ExpectedSequenceCount,
- SequenceIndexChunksLeftToWriteCounter.load());
- }
- ZEN_ASSERT(SequenceIndexChunksLeftToWriteCounter.load() <= ExpectedSequenceCount);
- }
- }
- ZEN_ASSERT(RawSequencesMissingWriteCount == 0);
- ZEN_ASSERT(m_WrittenChunkByteCount == BytesToWrite);
- ZEN_ASSERT(m_ValidatedChunkByteCount == BytesToValidate);
- }
+ VerifyWriteChunksComplete(SequenceIndexChunksLeftToWriteCounters, BytesToWrite, BytesToValidate);
const uint64_t DownloadedBytes = m_DownloadStats.DownloadedChunkByteCount.load() +
m_DownloadStats.DownloadedBlockByteCount.load() +
@@ -1809,17 +1046,15 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
{
CloneDetails = fmt::format(" ({} cloned)", NiceBytes(m_DiskStats.CloneByteCount.load()));
}
- ZEN_OPERATION_LOG_INFO(
- m_LogOutput,
- "Downloaded {} ({}bits/s) in {}. Wrote {} ({}B/s){} in {}. Completed in {}",
- NiceBytes(DownloadedBytes),
- NiceNum(GetBytesPerSecond(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS(), DownloadedBytes * 8)),
- NiceTimeSpanMs(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS() / 1000),
- NiceBytes(m_WrittenChunkByteCount.load()),
- NiceNum(GetBytesPerSecond(FilteredWrittenBytesPerSecond.GetElapsedTimeUS(), m_DiskStats.WriteByteCount.load())),
- CloneDetails,
- NiceTimeSpanMs(FilteredWrittenBytesPerSecond.GetElapsedTimeUS() / 1000),
- NiceTimeSpanMs(WriteTimer.GetElapsedTimeMs()));
+ ZEN_INFO("Downloaded {} ({}bits/s) in {}. Wrote {} ({}B/s){} in {}. Completed in {}",
+ NiceBytes(DownloadedBytes),
+ NiceNum(GetBytesPerSecond(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS(), DownloadedBytes * 8)),
+ NiceTimeSpanMs(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS() / 1000),
+ NiceBytes(m_WrittenChunkByteCount.load()),
+ NiceNum(GetBytesPerSecond(FilteredWrittenBytesPerSecond.GetElapsedTimeUS(), m_DiskStats.WriteByteCount.load())),
+ CloneDetails,
+ NiceTimeSpanMs(FilteredWrittenBytesPerSecond.GetElapsedTimeUS() / 1000),
+ NiceTimeSpanMs(WriteTimer.GetElapsedTimeMs()));
}
m_WriteChunkStats.WriteChunksElapsedWallTimeUs = WriteTimer.GetElapsedTimeUs();
@@ -1827,199 +1062,40 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
m_WriteChunkStats.WriteTimeUs = FilteredWrittenBytesPerSecond.GetElapsedTimeUS();
}
- if (m_Options.PrimeCacheOnly)
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::PrepareTarget, (uint32_t)TaskSteps::StepCount);
+
+ if (m_AbortFlag)
{
return;
}
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::PrepareTarget, (uint32_t)TaskSteps::StepCount);
-
- tsl::robin_map<uint32_t, uint32_t> RemotePathIndexToLocalPathIndex;
- RemotePathIndexToLocalPathIndex.reserve(m_RemoteContent.Paths.size());
-
- tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> SequenceHashToLocalPathIndex;
- std::vector<uint32_t> RemoveLocalPathIndexes;
+ LocalPathCategorization Categorization = CategorizeLocalPaths(RemotePathToRemoteIndex);
if (m_AbortFlag)
{
return;
}
+ std::atomic<uint64_t> CachedCount = 0;
+ std::atomic<uint64_t> CachedByteCount = 0;
+ ScheduleLocalFileCaching(Categorization.FilesToCache, CachedCount, CachedByteCount);
+ if (m_AbortFlag)
{
- ZEN_TRACE_CPU("PrepareTarget");
-
- tsl::robin_set<IoHash, IoHash::Hasher> CachedRemoteSequences;
-
- std::vector<uint32_t> FilesToCache;
-
- uint64_t MatchCount = 0;
- uint64_t PathMismatchCount = 0;
- uint64_t HashMismatchCount = 0;
- std::atomic<uint64_t> CachedCount = 0;
- std::atomic<uint64_t> CachedByteCount = 0;
- uint64_t SkippedCount = 0;
- uint64_t DeleteCount = 0;
- for (uint32_t LocalPathIndex = 0; LocalPathIndex < m_LocalContent.Paths.size(); LocalPathIndex++)
- {
- if (m_AbortFlag)
- {
- break;
- }
- const IoHash& RawHash = m_LocalContent.RawHashes[LocalPathIndex];
- const std::filesystem::path& LocalPath = m_LocalContent.Paths[LocalPathIndex];
-
- ZEN_ASSERT_SLOW(IsFile((m_Path / m_LocalContent.Paths[LocalPathIndex]).make_preferred()));
-
- if (m_Options.EnableTargetFolderScavenging)
- {
- if (!m_Options.WipeTargetFolder)
- {
- // Check if it is already in the correct place
- if (auto RemotePathIt = RemotePathToRemoteIndex.find(LocalPath.generic_string());
- RemotePathIt != RemotePathToRemoteIndex.end())
- {
- const uint32_t RemotePathIndex = RemotePathIt->second;
- if (m_RemoteContent.RawHashes[RemotePathIndex] == RawHash)
- {
- // It is already in it's correct place
- RemotePathIndexToLocalPathIndex[RemotePathIndex] = LocalPathIndex;
- SequenceHashToLocalPathIndex.insert({RawHash, LocalPathIndex});
- MatchCount++;
- continue;
- }
- else
- {
- HashMismatchCount++;
- }
- }
- else
- {
- PathMismatchCount++;
- }
- }
-
- // Do we need it?
- if (m_RemoteLookup.RawHashToSequenceIndex.contains(RawHash))
- {
- if (!CachedRemoteSequences.contains(RawHash))
- {
- // We need it, make sure we move it to the cache
- FilesToCache.push_back(LocalPathIndex);
- CachedRemoteSequences.insert(RawHash);
- continue;
- }
- else
- {
- SkippedCount++;
- }
- }
- }
-
- if (!m_Options.WipeTargetFolder)
- {
- // Explicitly delete the unneeded local file
- RemoveLocalPathIndexes.push_back(LocalPathIndex);
- DeleteCount++;
- }
- }
-
- if (m_AbortFlag)
- {
- return;
- }
-
- {
- ZEN_TRACE_CPU("CopyToCache");
-
- Stopwatch Timer;
-
- std::unique_ptr<OperationLogOutput::ProgressBar> CacheLocalProgressBarPtr(
- m_LogOutput.CreateProgressBar("Cache Local Data"));
- OperationLogOutput::ProgressBar& CacheLocalProgressBar(*CacheLocalProgressBarPtr);
- ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
-
- for (uint32_t LocalPathIndex : FilesToCache)
- {
- if (m_AbortFlag)
- {
- break;
- }
- Work.ScheduleWork(m_IOWorkerPool, [this, &CachedCount, &CachedByteCount, LocalPathIndex](std::atomic<bool>&) {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("Async_CopyToCache");
-
- const IoHash& RawHash = m_LocalContent.RawHashes[LocalPathIndex];
- const std::filesystem::path& LocalPath = m_LocalContent.Paths[LocalPathIndex];
- const std::filesystem::path CacheFilePath = GetFinalChunkedSequenceFileName(m_CacheFolderPath, RawHash);
- ZEN_ASSERT_SLOW(!IsFileWithRetry(CacheFilePath));
- const std::filesystem::path LocalFilePath = (m_Path / LocalPath).make_preferred();
-
- std::error_code Ec = RenameFileWithRetry(LocalFilePath, CacheFilePath);
- if (Ec)
- {
- ZEN_OPERATION_LOG_WARN(m_LogOutput,
- "Failed to move file from '{}' to '{}', reason: ({}) {}, retrying...",
- LocalFilePath,
- CacheFilePath,
- Ec.value(),
- Ec.message());
- Ec = RenameFileWithRetry(LocalFilePath, CacheFilePath);
- if (Ec)
- {
- throw std::system_error(std::error_code(Ec.value(), std::system_category()),
- fmt::format("Failed to file from '{}' to '{}', reason: ({}) {}",
- LocalFilePath,
- CacheFilePath,
- Ec.value(),
- Ec.message()));
- }
- }
-
- CachedCount++;
- CachedByteCount += m_LocalContent.RawSizes[LocalPathIndex];
- }
- });
- }
-
- {
- ZEN_TRACE_CPU("CopyToCache_Wait");
-
- Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
- ZEN_UNUSED(PendingWork);
- const uint64_t WorkTotal = FilesToCache.size();
- const uint64_t WorkComplete = CachedCount.load();
- std::string Details = fmt::format("{}/{} ({}) files", WorkComplete, WorkTotal, NiceBytes(CachedByteCount));
- CacheLocalProgressBar.UpdateState(
- {.Task = "Caching local ",
- .Details = Details,
- .TotalCount = gsl::narrow<uint64_t>(WorkTotal),
- .RemainingCount = gsl::narrow<uint64_t>(WorkTotal - WorkComplete),
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
- false);
- });
- }
-
- CacheLocalProgressBar.Finish();
- if (m_AbortFlag)
- {
- return;
- }
-
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput,
- "Local state prep: Match: {}, PathMismatch: {}, HashMismatch: {}, Cached: {} ({}), Skipped: {}, "
- "Delete: {}",
- MatchCount,
- PathMismatchCount,
- HashMismatchCount,
- CachedCount.load(),
- NiceBytes(CachedByteCount.load()),
- SkippedCount,
- DeleteCount);
- }
+ return;
}
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::FinalizeTarget, (uint32_t)TaskSteps::StepCount);
+ ZEN_DEBUG(
+ "Local state prep: Match: {}, PathMismatch: {}, HashMismatch: {}, Cached: {} ({}), Skipped: {}, "
+ "Delete: {}",
+ Categorization.MatchCount,
+ Categorization.PathMismatchCount,
+ Categorization.HashMismatchCount,
+ CachedCount.load(),
+ NiceBytes(CachedByteCount.load()),
+ Categorization.SkippedCount,
+ Categorization.DeleteCount);
+
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::FinalizeTarget, (uint32_t)TaskSteps::StepCount);
if (m_Options.WipeTargetFolder)
{
@@ -2027,9 +1103,16 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
Stopwatch Timer;
// Clean target folder
- if (!CleanDirectory(m_LogOutput, m_IOWorkerPool, m_AbortFlag, m_PauseFlag, m_Options.IsQuiet, m_Path, m_Options.ExcludeFolders))
+ if (!CleanDirectory(Log(),
+ m_Progress,
+ m_IOWorkerPool,
+ m_AbortFlag,
+ m_PauseFlag,
+ m_Options.IsQuiet,
+ m_Path,
+ m_Options.ExcludeFolders))
{
- ZEN_OPERATION_LOG_WARN(m_LogOutput, "Some files in {} could not be removed", m_Path);
+ ZEN_WARN("Some files in {} could not be removed", m_Path);
}
m_RebuildFolderStateStats.CleanFolderElapsedWallTimeUs = Timer.GetElapsedTimeUs();
}
@@ -2044,315 +1127,49 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
Stopwatch Timer;
- std::unique_ptr<OperationLogOutput::ProgressBar> RebuildProgressBarPtr(m_LogOutput.CreateProgressBar("Rebuild State"));
- OperationLogOutput::ProgressBar& RebuildProgressBar(*RebuildProgressBarPtr);
- ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Rebuild State");
+ ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
OutLocalFolderState.Paths.resize(m_RemoteContent.Paths.size());
OutLocalFolderState.RawSizes.resize(m_RemoteContent.Paths.size());
OutLocalFolderState.Attributes.resize(m_RemoteContent.Paths.size());
OutLocalFolderState.ModificationTicks.resize(m_RemoteContent.Paths.size());
- std::atomic<uint64_t> DeletedCount = 0;
-
- for (uint32_t LocalPathIndex : RemoveLocalPathIndexes)
- {
- if (m_AbortFlag)
- {
- break;
- }
- Work.ScheduleWork(m_IOWorkerPool, [this, &DeletedCount, LocalPathIndex](std::atomic<bool>&) {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("Async_RemoveFile");
-
- const std::filesystem::path LocalFilePath = (m_Path / m_LocalContent.Paths[LocalPathIndex]).make_preferred();
- SetFileReadOnlyWithRetry(LocalFilePath, false);
- RemoveFileWithRetry(LocalFilePath);
- DeletedCount++;
- }
- });
- }
-
+ std::atomic<uint64_t> DeletedCount = 0;
std::atomic<uint64_t> TargetsComplete = 0;
- struct FinalizeTarget
- {
- IoHash RawHash;
- uint32_t RemotePathIndex;
- };
-
- std::vector<FinalizeTarget> Targets;
- Targets.reserve(m_RemoteContent.Paths.size());
- for (uint32_t RemotePathIndex = 0; RemotePathIndex < m_RemoteContent.Paths.size(); RemotePathIndex++)
- {
- Targets.push_back(
- FinalizeTarget{.RawHash = m_RemoteContent.RawHashes[RemotePathIndex], .RemotePathIndex = RemotePathIndex});
- }
- std::sort(Targets.begin(), Targets.end(), [](const FinalizeTarget& Lhs, const FinalizeTarget& Rhs) {
- if (Lhs.RawHash < Rhs.RawHash)
- {
- return true;
- }
- else if (Lhs.RawHash > Rhs.RawHash)
- {
- return false;
- }
- return Lhs.RemotePathIndex < Rhs.RemotePathIndex;
- });
-
- size_t TargetOffset = 0;
- while (TargetOffset < Targets.size())
- {
- if (m_AbortFlag)
- {
- break;
- }
-
- size_t TargetCount = 1;
- while ((TargetOffset + TargetCount) < Targets.size() &&
- (Targets[TargetOffset + TargetCount].RawHash == Targets[TargetOffset].RawHash))
- {
- TargetCount++;
- }
-
- Work.ScheduleWork(
- m_IOWorkerPool,
- [this,
- &SequenceHashToLocalPathIndex,
- &Targets,
- &RemotePathIndexToLocalPathIndex,
- &OutLocalFolderState,
- BaseTargetOffset = TargetOffset,
- TargetCount,
- &TargetsComplete](std::atomic<bool>&) {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("Async_FinalizeChunkSequence");
+ ScheduleLocalFileRemovals(Work, Categorization.RemoveLocalPathIndexes, DeletedCount);
- size_t TargetOffset = BaseTargetOffset;
- const IoHash& RawHash = Targets[TargetOffset].RawHash;
+ std::vector<FinalizeTarget> Targets = BuildSortedFinalizeTargets();
- if (RawHash == IoHash::Zero)
- {
- ZEN_TRACE_CPU("CreateEmptyFiles");
- while (TargetOffset < (BaseTargetOffset + TargetCount))
- {
- const uint32_t RemotePathIndex = Targets[TargetOffset].RemotePathIndex;
- ZEN_ASSERT(Targets[TargetOffset].RawHash == RawHash);
- const std::filesystem::path& TargetPath = m_RemoteContent.Paths[RemotePathIndex];
- std::filesystem::path TargetFilePath = (m_Path / TargetPath).make_preferred();
- if (!RemotePathIndexToLocalPathIndex[RemotePathIndex])
- {
- if (IsFileWithRetry(TargetFilePath))
- {
- SetFileReadOnlyWithRetry(TargetFilePath, false);
- }
- else
- {
- CreateDirectories(TargetFilePath.parent_path());
- }
- BasicFile OutputFile;
- OutputFile.Open(TargetFilePath, BasicFile::Mode::kTruncate);
- }
- OutLocalFolderState.Paths[RemotePathIndex] = TargetPath;
- OutLocalFolderState.RawSizes[RemotePathIndex] = m_RemoteContent.RawSizes[RemotePathIndex];
-
- OutLocalFolderState.Attributes[RemotePathIndex] =
- m_RemoteContent.Attributes.empty()
- ? GetNativeFileAttributes(TargetFilePath)
- : SetNativeFileAttributes(TargetFilePath,
- m_RemoteContent.Platform,
- m_RemoteContent.Attributes[RemotePathIndex]);
- OutLocalFolderState.ModificationTicks[RemotePathIndex] = GetModificationTickFromPath(TargetFilePath);
-
- TargetOffset++;
- TargetsComplete++;
- }
- }
- else
- {
- ZEN_TRACE_CPU("FinalizeFile");
- ZEN_ASSERT(m_RemoteLookup.RawHashToSequenceIndex.contains(RawHash));
- const uint32_t FirstRemotePathIndex = Targets[TargetOffset].RemotePathIndex;
- const std::filesystem::path& FirstTargetPath = m_RemoteContent.Paths[FirstRemotePathIndex];
- std::filesystem::path FirstTargetFilePath = (m_Path / FirstTargetPath).make_preferred();
-
- if (auto InPlaceIt = RemotePathIndexToLocalPathIndex.find(FirstRemotePathIndex);
- InPlaceIt != RemotePathIndexToLocalPathIndex.end())
- {
- ZEN_ASSERT_SLOW(IsFileWithRetry(FirstTargetFilePath));
- }
- else
- {
- if (IsFileWithRetry(FirstTargetFilePath))
- {
- SetFileReadOnlyWithRetry(FirstTargetFilePath, false);
- }
- else
- {
- CreateDirectories(FirstTargetFilePath.parent_path());
- }
-
- if (auto InplaceIt = SequenceHashToLocalPathIndex.find(RawHash);
- InplaceIt != SequenceHashToLocalPathIndex.end())
- {
- ZEN_TRACE_CPU("Copy");
- const uint32_t LocalPathIndex = InplaceIt->second;
- const std::filesystem::path& SourcePath = m_LocalContent.Paths[LocalPathIndex];
- std::filesystem::path SourceFilePath = (m_Path / SourcePath).make_preferred();
- ZEN_ASSERT_SLOW(IsFileWithRetry(SourceFilePath));
-
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput,
- "Copying from '{}' -> '{}'",
- SourceFilePath,
- FirstTargetFilePath);
- const uint64_t RawSize = m_LocalContent.RawSizes[LocalPathIndex];
- FastCopyFile(m_Options.AllowFileClone,
- m_Options.UseSparseFiles,
- SourceFilePath,
- FirstTargetFilePath,
- RawSize,
- m_DiskStats.WriteCount,
- m_DiskStats.WriteByteCount,
- m_DiskStats.CloneCount,
- m_DiskStats.CloneByteCount);
-
- m_RebuildFolderStateStats.FinalizeTreeFilesCopiedCount++;
- }
- else
- {
- ZEN_TRACE_CPU("Rename");
- const std::filesystem::path CacheFilePath =
- GetFinalChunkedSequenceFileName(m_CacheFolderPath, RawHash);
- ZEN_ASSERT_SLOW(IsFileWithRetry(CacheFilePath));
-
- std::error_code Ec = RenameFileWithRetry(CacheFilePath, FirstTargetFilePath);
- if (Ec)
- {
- ZEN_OPERATION_LOG_WARN(m_LogOutput,
- "Failed to move file from '{}' to '{}', reason: ({}) {}, retrying...",
- CacheFilePath,
- FirstTargetFilePath,
- Ec.value(),
- Ec.message());
- Ec = RenameFileWithRetry(CacheFilePath, FirstTargetFilePath);
- if (Ec)
- {
- throw std::system_error(
- std::error_code(Ec.value(), std::system_category()),
- fmt::format("Failed to move file from '{}' to '{}', reason: ({}) {}",
- CacheFilePath,
- FirstTargetFilePath,
- Ec.value(),
- Ec.message()));
- }
- }
-
- m_RebuildFolderStateStats.FinalizeTreeFilesMovedCount++;
- }
- }
-
- OutLocalFolderState.Paths[FirstRemotePathIndex] = FirstTargetPath;
- OutLocalFolderState.RawSizes[FirstRemotePathIndex] = m_RemoteContent.RawSizes[FirstRemotePathIndex];
-
- OutLocalFolderState.Attributes[FirstRemotePathIndex] =
- m_RemoteContent.Attributes.empty()
- ? GetNativeFileAttributes(FirstTargetFilePath)
- : SetNativeFileAttributes(FirstTargetFilePath,
- m_RemoteContent.Platform,
- m_RemoteContent.Attributes[FirstRemotePathIndex]);
- OutLocalFolderState.ModificationTicks[FirstRemotePathIndex] =
- GetModificationTickFromPath(FirstTargetFilePath);
-
- TargetOffset++;
- TargetsComplete++;
-
- while (TargetOffset < (BaseTargetOffset + TargetCount))
- {
- const uint32_t RemotePathIndex = Targets[TargetOffset].RemotePathIndex;
- ZEN_ASSERT(Targets[TargetOffset].RawHash == RawHash);
- const std::filesystem::path& TargetPath = m_RemoteContent.Paths[RemotePathIndex];
- std::filesystem::path TargetFilePath = (m_Path / TargetPath).make_preferred();
-
- if (auto InPlaceIt = RemotePathIndexToLocalPathIndex.find(RemotePathIndex);
- InPlaceIt != RemotePathIndexToLocalPathIndex.end())
- {
- ZEN_ASSERT_SLOW(IsFileWithRetry(TargetFilePath));
- }
- else
- {
- ZEN_TRACE_CPU("Copy");
- if (IsFileWithRetry(TargetFilePath))
- {
- SetFileReadOnlyWithRetry(TargetFilePath, false);
- }
- else
- {
- CreateDirectories(TargetFilePath.parent_path());
- }
-
- ZEN_ASSERT_SLOW(IsFileWithRetry(FirstTargetFilePath));
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput,
- "Copying from '{}' -> '{}'",
- FirstTargetFilePath,
- TargetFilePath);
- const uint64_t RawSize = m_RemoteContent.RawSizes[RemotePathIndex];
- FastCopyFile(m_Options.AllowFileClone,
- m_Options.UseSparseFiles,
- FirstTargetFilePath,
- TargetFilePath,
- RawSize,
- m_DiskStats.WriteCount,
- m_DiskStats.WriteByteCount,
- m_DiskStats.CloneCount,
- m_DiskStats.CloneByteCount);
-
- m_RebuildFolderStateStats.FinalizeTreeFilesCopiedCount++;
- }
-
- OutLocalFolderState.Paths[RemotePathIndex] = TargetPath;
- OutLocalFolderState.RawSizes[RemotePathIndex] = m_RemoteContent.RawSizes[RemotePathIndex];
-
- OutLocalFolderState.Attributes[RemotePathIndex] =
- m_RemoteContent.Attributes.empty()
- ? GetNativeFileAttributes(TargetFilePath)
- : SetNativeFileAttributes(TargetFilePath,
- m_RemoteContent.Platform,
- m_RemoteContent.Attributes[RemotePathIndex]);
- OutLocalFolderState.ModificationTicks[RemotePathIndex] = GetModificationTickFromPath(TargetFilePath);
-
- TargetOffset++;
- TargetsComplete++;
- }
- }
- }
- });
-
- TargetOffset += TargetCount;
- }
+ ScheduleTargetFinalization(Work,
+ Targets,
+ Categorization.SequenceHashToLocalPathIndex,
+ Categorization.RemotePathIndexToLocalPathIndex,
+ OutLocalFolderState,
+ TargetsComplete);
{
ZEN_TRACE_CPU("FinalizeTree_Wait");
- Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
+ Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
ZEN_UNUSED(PendingWork);
- const uint64_t WorkTotal = Targets.size() + RemoveLocalPathIndexes.size();
+ const uint64_t WorkTotal = Targets.size() + Categorization.RemoveLocalPathIndexes.size();
const uint64_t WorkComplete = TargetsComplete.load() + DeletedCount.load();
std::string Details = fmt::format("{}/{} files", WorkComplete, WorkTotal);
- RebuildProgressBar.UpdateState({.Task = "Rebuilding state ",
- .Details = Details,
- .TotalCount = gsl::narrow<uint64_t>(WorkTotal),
- .RemainingCount = gsl::narrow<uint64_t>(WorkTotal - WorkComplete),
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
- false);
+ ProgressBar->UpdateState({.Task = "Rebuilding state ",
+ .Details = Details,
+ .TotalCount = gsl::narrow<uint64_t>(WorkTotal),
+ .RemainingCount = gsl::narrow<uint64_t>(WorkTotal - WorkComplete),
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ false);
});
}
m_RebuildFolderStateStats.FinalizeTreeElapsedWallTimeUs = Timer.GetElapsedTimeUs();
- RebuildProgressBar.Finish();
+ ProgressBar->Finish();
}
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::Cleanup, (uint32_t)TaskSteps::StepCount);
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::Cleanup, (uint32_t)TaskSteps::StepCount);
}
catch (const std::exception&)
{
@@ -2416,11 +1233,7 @@ BuildsOperationUpdateFolder::ScanCacheFolder(tsl::robin_map<IoHash, uint32_t, Io
std::error_code Ec = TryRemoveFile(CacheDirContent.Files[Index]);
if (Ec)
{
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput,
- "Failed removing file '{}', reason: ({}) {}",
- CacheDirContent.Files[Index],
- Ec.value(),
- Ec.message());
+ ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", CacheDirContent.Files[Index], Ec.value(), Ec.message());
}
}
m_CacheMappingStats.CacheScanElapsedWallTimeUs += CacheTimer.GetElapsedTimeUs();
@@ -2476,17 +1289,1302 @@ BuildsOperationUpdateFolder::ScanTempBlocksFolder(tsl::robin_map<IoHash, uint32_
std::error_code Ec = TryRemoveFile(BlockDirContent.Files[Index]);
if (Ec)
{
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput,
- "Failed removing file '{}', reason: ({}) {}",
- BlockDirContent.Files[Index],
- Ec.value(),
- Ec.message());
+ ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", BlockDirContent.Files[Index], Ec.value(), Ec.message());
}
}
m_CacheMappingStats.CacheScanElapsedWallTimeUs += CacheTimer.GetElapsedTimeUs();
}
+void
+BuildsOperationUpdateFolder::InitializeSequenceCounters(std::vector<std::atomic<uint32_t>>& OutSequenceCounters,
+ tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutSequencesLeftToFind,
+ const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& CachedChunkHashesFound,
+ const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& CachedSequenceHashesFound)
+{
+ if (m_Options.EnableTargetFolderScavenging)
+ {
+ // Pick up all whole files we can use from current local state
+ ZEN_TRACE_CPU("GetLocalSequences");
+
+ std::vector<uint32_t> MissingSequenceIndexes = ScanTargetFolder(CachedChunkHashesFound, CachedSequenceHashesFound);
+
+ for (uint32_t RemoteSequenceIndex : MissingSequenceIndexes)
+ {
+ // We must write the sequence
+ const uint32_t ChunkCount = m_RemoteContent.ChunkedContent.ChunkCounts[RemoteSequenceIndex];
+ const IoHash& RemoteSequenceRawHash = m_RemoteContent.ChunkedContent.SequenceRawHashes[RemoteSequenceIndex];
+ OutSequenceCounters[RemoteSequenceIndex] = ChunkCount;
+ OutSequencesLeftToFind.insert({RemoteSequenceRawHash, RemoteSequenceIndex});
+ }
+ }
+ else
+ {
+ for (uint32_t RemoteSequenceIndex = 0; RemoteSequenceIndex < m_RemoteContent.ChunkedContent.SequenceRawHashes.size();
+ RemoteSequenceIndex++)
+ {
+ OutSequenceCounters[RemoteSequenceIndex] = m_RemoteContent.ChunkedContent.ChunkCounts[RemoteSequenceIndex];
+ }
+ }
+}
+
+void
+BuildsOperationUpdateFolder::MatchScavengedSequencesToRemote(std::span<const ChunkedFolderContent> Contents,
+ std::span<const ChunkedContentLookup> Lookups,
+ std::span<const std::filesystem::path> Paths,
+ tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& InOutSequencesLeftToFind,
+ std::vector<std::atomic<uint32_t>>& InOutSequenceCounters,
+ std::vector<ScavengedSequenceCopyOperation>& OutCopyOperations,
+ uint64_t& OutScavengedPathsCount)
+{
+ for (uint32_t ScavengedContentIndex = 0; ScavengedContentIndex < Contents.size() && !InOutSequencesLeftToFind.empty();
+ ScavengedContentIndex++)
+ {
+ const std::filesystem::path& ScavengePath = Paths[ScavengedContentIndex];
+ if (ScavengePath.empty())
+ {
+ continue;
+ }
+ const ChunkedFolderContent& ScavengedLocalContent = Contents[ScavengedContentIndex];
+ const ChunkedContentLookup& ScavengedLookup = Lookups[ScavengedContentIndex];
+
+ for (uint32_t ScavengedSequenceIndex = 0; ScavengedSequenceIndex < ScavengedLocalContent.ChunkedContent.SequenceRawHashes.size();
+ ScavengedSequenceIndex++)
+ {
+ const IoHash& SequenceRawHash = ScavengedLocalContent.ChunkedContent.SequenceRawHashes[ScavengedSequenceIndex];
+ auto It = InOutSequencesLeftToFind.find(SequenceRawHash);
+ if (It == InOutSequencesLeftToFind.end())
+ {
+ continue;
+ }
+ const uint32_t RemoteSequenceIndex = It->second;
+ const uint64_t RawSize = m_RemoteContent.RawSizes[m_RemoteLookup.SequenceIndexFirstPathIndex[RemoteSequenceIndex]];
+ ZEN_ASSERT(RawSize > 0);
+
+ const uint32_t ScavengedPathIndex = ScavengedLookup.SequenceIndexFirstPathIndex[ScavengedSequenceIndex];
+ ZEN_ASSERT_SLOW(IsFile((ScavengePath / ScavengedLocalContent.Paths[ScavengedPathIndex]).make_preferred()));
+
+ OutCopyOperations.push_back({.ScavengedContentIndex = ScavengedContentIndex,
+ .ScavengedPathIndex = ScavengedPathIndex,
+ .RemoteSequenceIndex = RemoteSequenceIndex,
+ .RawSize = RawSize});
+
+ InOutSequencesLeftToFind.erase(SequenceRawHash);
+ InOutSequenceCounters[RemoteSequenceIndex] = 0;
+
+ m_CacheMappingStats.ScavengedPathsMatchingSequencesCount++;
+ m_CacheMappingStats.ScavengedPathsMatchingSequencesByteCount += RawSize;
+ }
+ OutScavengedPathsCount++;
+ }
+}
+
+uint64_t
+BuildsOperationUpdateFolder::CalculateBytesToWriteAndFlagNeededChunks(std::span<const std::atomic<uint32_t>> SequenceCounters,
+ const std::vector<bool>& NeedsCopyFromLocalFileFlags,
+ std::span<std::atomic<bool>> OutNeedsCopyFromSourceFlags)
+{
+ uint64_t BytesToWrite = 0;
+ for (uint32_t RemoteChunkIndex = 0; RemoteChunkIndex < m_RemoteContent.ChunkedContent.ChunkHashes.size(); RemoteChunkIndex++)
+ {
+ const uint64_t ChunkWriteCount = GetChunkWriteCount(SequenceCounters, RemoteChunkIndex);
+ if (ChunkWriteCount > 0)
+ {
+ BytesToWrite += m_RemoteContent.ChunkedContent.ChunkRawSizes[RemoteChunkIndex] * ChunkWriteCount;
+ if (!NeedsCopyFromLocalFileFlags[RemoteChunkIndex])
+ {
+ OutNeedsCopyFromSourceFlags[RemoteChunkIndex] = true;
+ }
+ }
+ }
+ return BytesToWrite;
+}
+
+void
+BuildsOperationUpdateFolder::ClassifyCachedAndFetchBlocks(std::span<const ChunkBlockAnalyser::NeededBlock> NeededBlocks,
+ const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& CachedBlocksFound,
+ uint64_t& TotalPartWriteCount,
+ std::vector<uint32_t>& OutCachedChunkBlockIndexes,
+ std::vector<uint32_t>& OutFetchBlockIndexes)
+{
+ ZEN_TRACE_CPU("BlockCacheFileExists");
+ for (const ChunkBlockAnalyser::NeededBlock& NeededBlock : NeededBlocks)
+ {
+ const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[NeededBlock.BlockIndex];
+ bool UsingCachedBlock = false;
+ if (auto It = CachedBlocksFound.find(BlockDescription.BlockHash); It != CachedBlocksFound.end())
+ {
+ TotalPartWriteCount++;
+
+ std::filesystem::path BlockPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString();
+ if (IsFile(BlockPath))
+ {
+ OutCachedChunkBlockIndexes.push_back(NeededBlock.BlockIndex);
+ UsingCachedBlock = true;
+ }
+ }
+ if (!UsingCachedBlock)
+ {
+ OutFetchBlockIndexes.push_back(NeededBlock.BlockIndex);
+ }
+ }
+}
+
+std::vector<uint32_t>
+BuildsOperationUpdateFolder::DetermineNeededLooseChunkIndexes(std::span<const std::atomic<uint32_t>> SequenceCounters,
+ const std::vector<bool>& NeedsCopyFromLocalFileFlags,
+ std::span<std::atomic<bool>> NeedsCopyFromSourceFlags)
+{
+ std::vector<uint32_t> NeededLooseChunkIndexes;
+ NeededLooseChunkIndexes.reserve(m_LooseChunkHashes.size());
+ for (uint32_t LooseChunkIndex = 0; LooseChunkIndex < m_LooseChunkHashes.size(); LooseChunkIndex++)
+ {
+ const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex];
+ auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash);
+ ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end());
+ const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second;
+
+ if (NeedsCopyFromLocalFileFlags[RemoteChunkIndex])
+ {
+ if (m_Options.IsVerbose)
+ {
+ ZEN_INFO("Skipping chunk {} due to cache reuse", m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]);
+ }
+ continue;
+ }
+
+ bool NeedsCopy = true;
+ if (NeedsCopyFromSourceFlags[RemoteChunkIndex].compare_exchange_strong(NeedsCopy, false))
+ {
+ const uint64_t WriteCount = GetChunkWriteCount(SequenceCounters, RemoteChunkIndex);
+ if (WriteCount == 0)
+ {
+ if (m_Options.IsVerbose)
+ {
+ ZEN_INFO("Skipping chunk {} due to cache reuse", m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]);
+ }
+ }
+ else
+ {
+ NeededLooseChunkIndexes.push_back(LooseChunkIndex);
+ }
+ }
+ }
+ return NeededLooseChunkIndexes;
+}
+
+BuildsOperationUpdateFolder::BlobsExistsResult
+BuildsOperationUpdateFolder::QueryBlobCacheExists(std::span<const uint32_t> NeededLooseChunkIndexes,
+ std::span<const uint32_t> FetchBlockIndexes)
+{
+ BlobsExistsResult Result;
+ if (!m_Storage.CacheStorage)
+ {
+ return Result;
+ }
+
+ ZEN_TRACE_CPU("BlobCacheExistCheck");
+ Stopwatch Timer;
+
+ std::vector<IoHash> BlobHashes;
+ BlobHashes.reserve(NeededLooseChunkIndexes.size() + FetchBlockIndexes.size());
+
+ for (const uint32_t LooseChunkIndex : NeededLooseChunkIndexes)
+ {
+ BlobHashes.push_back(m_LooseChunkHashes[LooseChunkIndex]);
+ }
+
+ for (uint32_t BlockIndex : FetchBlockIndexes)
+ {
+ BlobHashes.push_back(m_BlockDescriptions[BlockIndex].BlockHash);
+ }
+
+ const std::vector<BuildStorageCache::BlobExistsResult> CacheExistsResult = m_Storage.CacheStorage->BlobsExists(m_BuildId, BlobHashes);
+
+ if (CacheExistsResult.size() == BlobHashes.size())
+ {
+ Result.ExistingBlobs.reserve(CacheExistsResult.size());
+ for (size_t BlobIndex = 0; BlobIndex < BlobHashes.size(); BlobIndex++)
+ {
+ if (CacheExistsResult[BlobIndex].HasBody)
+ {
+ Result.ExistingBlobs.insert(BlobHashes[BlobIndex]);
+ }
+ }
+ }
+ Result.ElapsedTimeMs = Timer.GetElapsedTimeMs();
+ if (!Result.ExistingBlobs.empty() && !m_Options.IsQuiet)
+ {
+ ZEN_INFO("Remote cache : Found {} out of {} needed blobs in {}",
+ Result.ExistingBlobs.size(),
+ BlobHashes.size(),
+ NiceTimeSpanMs(Result.ElapsedTimeMs));
+ }
+ return Result;
+}
+
+std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode>
+BuildsOperationUpdateFolder::DeterminePartialDownloadModes(const BlobsExistsResult& ExistsResult)
+{
+ std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> Modes;
+
+ if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Off)
+ {
+ Modes.resize(m_BlockDescriptions.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::Off);
+ return Modes;
+ }
+
+ const bool MultiRangeCache = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1;
+ const bool MultiRangeBuild = m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest > 1;
+ ChunkBlockAnalyser::EPartialBlockDownloadMode CachePartialDownloadMode =
+ MultiRangeCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed
+ : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange;
+ ChunkBlockAnalyser::EPartialBlockDownloadMode CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off;
+
+ switch (m_Options.PartialBlockRequestMode)
+ {
+ case EPartialBlockRequestMode::Off:
+ break;
+ case EPartialBlockRequestMode::ZenCacheOnly:
+ CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off;
+ break;
+ case EPartialBlockRequestMode::Mixed:
+ CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange;
+ break;
+ case EPartialBlockRequestMode::All:
+ CloudPartialDownloadMode = MultiRangeBuild ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange
+ : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange;
+ break;
+ default:
+ ZEN_ASSERT(false);
+ break;
+ }
+
+ Modes.reserve(m_BlockDescriptions.size());
+ for (uint32_t BlockIndex = 0; BlockIndex < m_BlockDescriptions.size(); BlockIndex++)
+ {
+ const bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(m_BlockDescriptions[BlockIndex].BlockHash);
+ Modes.push_back(BlockExistInCache ? CachePartialDownloadMode : CloudPartialDownloadMode);
+ }
+ return Modes;
+}
+
+std::vector<BuildsOperationUpdateFolder::LooseChunkHashWorkData>
+BuildsOperationUpdateFolder::BuildLooseChunkHashWorks(std::span<const uint32_t> NeededLooseChunkIndexes,
+ std::span<const std::atomic<uint32_t>> SequenceCounters)
+{
+ std::vector<LooseChunkHashWorkData> LooseChunkHashWorks;
+ LooseChunkHashWorks.reserve(NeededLooseChunkIndexes.size());
+ for (uint32_t LooseChunkIndex : NeededLooseChunkIndexes)
+ {
+ const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex];
+ auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash);
+ ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end());
+ const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second;
+
+ std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs =
+ GetRemainingChunkTargets(SequenceCounters, RemoteChunkIndex);
+
+ ZEN_ASSERT(!ChunkTargetPtrs.empty());
+ LooseChunkHashWorks.push_back(LooseChunkHashWorkData{.ChunkTargetPtrs = ChunkTargetPtrs, .RemoteChunkIndex = RemoteChunkIndex});
+ }
+ return LooseChunkHashWorks;
+}
+
+void
+BuildsOperationUpdateFolder::VerifyWriteChunksComplete(std::span<const std::atomic<uint32_t>> SequenceCounters,
+ uint64_t BytesToWrite,
+ uint64_t BytesToValidate)
+{
+ uint32_t RawSequencesMissingWriteCount = 0;
+ for (uint32_t SequenceIndex = 0; SequenceIndex < SequenceCounters.size(); SequenceIndex++)
+ {
+ const auto& Counter = SequenceCounters[SequenceIndex];
+ if (Counter.load() != 0)
+ {
+ RawSequencesMissingWriteCount++;
+ const uint32_t PathIndex = m_RemoteLookup.SequenceIndexFirstPathIndex[SequenceIndex];
+ const std::filesystem::path& IncompletePath = m_RemoteContent.Paths[PathIndex];
+ ZEN_ASSERT(!IncompletePath.empty());
+ const uint32_t ExpectedSequenceCount = m_RemoteContent.ChunkedContent.ChunkCounts[SequenceIndex];
+ if (!m_Options.IsQuiet)
+ {
+ ZEN_INFO("{}: Max count {}, Current count {}", IncompletePath, ExpectedSequenceCount, Counter.load());
+ }
+ ZEN_ASSERT(Counter.load() <= ExpectedSequenceCount);
+ }
+ }
+ ZEN_ASSERT(RawSequencesMissingWriteCount == 0);
+ ZEN_ASSERT(m_WrittenChunkByteCount == BytesToWrite);
+ ZEN_ASSERT(m_ValidatedChunkByteCount == BytesToValidate);
+}
+
+std::vector<BuildsOperationUpdateFolder::FinalizeTarget>
+BuildsOperationUpdateFolder::BuildSortedFinalizeTargets()
+{
+ std::vector<FinalizeTarget> Targets;
+ Targets.reserve(m_RemoteContent.Paths.size());
+ for (uint32_t RemotePathIndex = 0; RemotePathIndex < m_RemoteContent.Paths.size(); RemotePathIndex++)
+ {
+ Targets.push_back(FinalizeTarget{.RawHash = m_RemoteContent.RawHashes[RemotePathIndex], .RemotePathIndex = RemotePathIndex});
+ }
+ std::sort(Targets.begin(), Targets.end(), [](const FinalizeTarget& Lhs, const FinalizeTarget& Rhs) {
+ return std::tie(Lhs.RawHash, Lhs.RemotePathIndex) < std::tie(Rhs.RawHash, Rhs.RemotePathIndex);
+ });
+ return Targets;
+}
+
+void
+BuildsOperationUpdateFolder::ScanScavengeSources(std::span<const ScavengeSource> Sources,
+ std::vector<ChunkedFolderContent>& OutContents,
+ std::vector<ChunkedContentLookup>& OutLookups,
+ std::vector<std::filesystem::path>& OutPaths)
+{
+ ZEN_TRACE_CPU("ScanScavengeSources");
+
+ const size_t ScavengePathCount = Sources.size();
+ OutContents.resize(ScavengePathCount);
+ OutLookups.resize(ScavengePathCount);
+ OutPaths.resize(ScavengePathCount);
+
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Scavenging");
+
+ ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+
+ std::atomic<uint64_t> PathsFound(0);
+ std::atomic<uint64_t> ChunksFound(0);
+ std::atomic<uint64_t> PathsScavenged(0);
+
+ for (size_t ScavengeIndex = 0; ScavengeIndex < ScavengePathCount; ScavengeIndex++)
+ {
+ Work.ScheduleWork(m_IOWorkerPool,
+ [this, &Sources, &OutContents, &OutPaths, &OutLookups, &PathsFound, &ChunksFound, &PathsScavenged, ScavengeIndex](
+ std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ ZEN_TRACE_CPU("Async_FindScavengeContent");
+
+ const ScavengeSource& Source = Sources[ScavengeIndex];
+ ChunkedFolderContent& ScavengedLocalContent = OutContents[ScavengeIndex];
+ ChunkedContentLookup& ScavengedLookup = OutLookups[ScavengeIndex];
+
+ if (FindScavengeContent(Source, ScavengedLocalContent, ScavengedLookup))
+ {
+ OutPaths[ScavengeIndex] = Source.Path;
+ PathsFound += ScavengedLocalContent.Paths.size();
+ ChunksFound += ScavengedLocalContent.ChunkedContent.ChunkHashes.size();
+ }
+ else
+ {
+ OutPaths[ScavengeIndex].clear();
+ }
+ PathsScavenged++;
+ }
+ });
+ }
+ {
+ ZEN_TRACE_CPU("ScavengeScan_Wait");
+
+ Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
+ ZEN_UNUSED(PendingWork);
+ std::string Details = fmt::format("{}/{} scanned. {} paths and {} chunks found for scavenging",
+ PathsScavenged.load(),
+ ScavengePathCount,
+ PathsFound.load(),
+ ChunksFound.load());
+ ProgressBar->UpdateState({.Task = "Scavenging ",
+ .Details = Details,
+ .TotalCount = ScavengePathCount,
+ .RemainingCount = ScavengePathCount - PathsScavenged.load(),
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ false);
+ });
+ }
+
+ ProgressBar->Finish();
+}
+
+BuildsOperationUpdateFolder::LocalPathCategorization
+BuildsOperationUpdateFolder::CategorizeLocalPaths(const tsl::robin_map<std::string, uint32_t>& RemotePathToRemoteIndex)
+{
+ ZEN_TRACE_CPU("PrepareTarget");
+
+ LocalPathCategorization Result;
+ tsl::robin_set<IoHash, IoHash::Hasher> CachedRemoteSequences;
+
+ Result.RemotePathIndexToLocalPathIndex.reserve(m_RemoteContent.Paths.size());
+
+ for (uint32_t LocalPathIndex = 0; LocalPathIndex < m_LocalContent.Paths.size(); LocalPathIndex++)
+ {
+ if (m_AbortFlag)
+ {
+ break;
+ }
+ const IoHash& RawHash = m_LocalContent.RawHashes[LocalPathIndex];
+ const std::filesystem::path& LocalPath = m_LocalContent.Paths[LocalPathIndex];
+
+ ZEN_ASSERT_SLOW(IsFile((m_Path / m_LocalContent.Paths[LocalPathIndex]).make_preferred()));
+
+ if (m_Options.EnableTargetFolderScavenging)
+ {
+ if (!m_Options.WipeTargetFolder)
+ {
+ // Check if it is already in the correct place
+ if (auto RemotePathIt = RemotePathToRemoteIndex.find(LocalPath.generic_string());
+ RemotePathIt != RemotePathToRemoteIndex.end())
+ {
+ const uint32_t RemotePathIndex = RemotePathIt->second;
+ if (m_RemoteContent.RawHashes[RemotePathIndex] == RawHash)
+ {
+ // It is already in it's correct place
+ Result.RemotePathIndexToLocalPathIndex[RemotePathIndex] = LocalPathIndex;
+ Result.SequenceHashToLocalPathIndex.insert({RawHash, LocalPathIndex});
+ Result.MatchCount++;
+ continue;
+ }
+ else
+ {
+ Result.HashMismatchCount++;
+ }
+ }
+ else
+ {
+ Result.PathMismatchCount++;
+ }
+ }
+
+ // Do we need it?
+ if (m_RemoteLookup.RawHashToSequenceIndex.contains(RawHash))
+ {
+ if (!CachedRemoteSequences.contains(RawHash))
+ {
+ // We need it, make sure we move it to the cache
+ Result.FilesToCache.push_back(LocalPathIndex);
+ CachedRemoteSequences.insert(RawHash);
+ continue;
+ }
+ else
+ {
+ Result.SkippedCount++;
+ }
+ }
+ }
+
+ if (!m_Options.WipeTargetFolder)
+ {
+ // Explicitly delete the unneeded local file
+ Result.RemoveLocalPathIndexes.push_back(LocalPathIndex);
+ Result.DeleteCount++;
+ }
+ }
+
+ return Result;
+}
+
+void
+BuildsOperationUpdateFolder::ScheduleLocalFileCaching(std::span<const uint32_t> FilesToCache,
+ std::atomic<uint64_t>& OutCachedCount,
+ std::atomic<uint64_t>& OutCachedByteCount)
+{
+ ZEN_TRACE_CPU("CopyToCache");
+
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Cache Local Data");
+ ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+
+ for (uint32_t LocalPathIndex : FilesToCache)
+ {
+ if (m_AbortFlag)
+ {
+ break;
+ }
+ Work.ScheduleWork(m_IOWorkerPool, [this, &OutCachedCount, &OutCachedByteCount, LocalPathIndex](std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ ZEN_TRACE_CPU("Async_CopyToCache");
+
+ const IoHash& RawHash = m_LocalContent.RawHashes[LocalPathIndex];
+ const std::filesystem::path& LocalPath = m_LocalContent.Paths[LocalPathIndex];
+ const std::filesystem::path CacheFilePath = GetFinalChunkedSequenceFileName(m_CacheFolderPath, RawHash);
+ ZEN_ASSERT_SLOW(!IsFileWithRetry(CacheFilePath));
+ const std::filesystem::path LocalFilePath = (m_Path / LocalPath).make_preferred();
+
+ std::error_code Ec = RenameFileWithRetry(LocalFilePath, CacheFilePath);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to move file from '{}' to '{}', reason: ({}) {}, retrying...",
+ LocalFilePath,
+ CacheFilePath,
+ Ec.value(),
+ Ec.message());
+ Ec = RenameFileWithRetry(LocalFilePath, CacheFilePath);
+ if (Ec)
+ {
+ throw std::system_error(std::error_code(Ec.value(), std::system_category()),
+ fmt::format("Failed to file from '{}' to '{}', reason: ({}) {}",
+ LocalFilePath,
+ CacheFilePath,
+ Ec.value(),
+ Ec.message()));
+ }
+ }
+
+ OutCachedCount++;
+ OutCachedByteCount += m_LocalContent.RawSizes[LocalPathIndex];
+ }
+ });
+ }
+
+ {
+ ZEN_TRACE_CPU("CopyToCache_Wait");
+
+ Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
+ ZEN_UNUSED(PendingWork);
+ const uint64_t WorkTotal = FilesToCache.size();
+ const uint64_t WorkComplete = OutCachedCount.load();
+ std::string Details = fmt::format("{}/{} ({}) files", WorkComplete, WorkTotal, NiceBytes(OutCachedByteCount));
+ ProgressBar->UpdateState({.Task = "Caching local ",
+ .Details = Details,
+ .TotalCount = gsl::narrow<uint64_t>(WorkTotal),
+ .RemainingCount = gsl::narrow<uint64_t>(WorkTotal - WorkComplete),
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ false);
+ });
+ }
+
+ ProgressBar->Finish();
+}
+
+void
+BuildsOperationUpdateFolder::ScheduleScavengedSequenceWrites(WriteChunksContext& Context,
+ std::span<const ScavengedSequenceCopyOperation> CopyOperations,
+ const std::vector<ChunkedFolderContent>& ScavengedContents,
+ const std::vector<std::filesystem::path>& ScavengedPaths)
+{
+ for (uint32_t ScavengeOpIndex = 0; ScavengeOpIndex < CopyOperations.size(); ScavengeOpIndex++)
+ {
+ if (m_AbortFlag)
+ {
+ break;
+ }
+ Context.Work.ScheduleWork(
+ m_IOWorkerPool,
+ [this, &Context, &CopyOperations, &ScavengedContents, &ScavengedPaths, ScavengeOpIndex](std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ ZEN_TRACE_CPU("Async_WriteScavenged");
+
+ Context.FilteredWrittenBytesPerSecond.Start();
+
+ const ScavengedSequenceCopyOperation& ScavengeOp = CopyOperations[ScavengeOpIndex];
+ const ChunkedFolderContent& ScavengedContent = ScavengedContents[ScavengeOp.ScavengedContentIndex];
+ const std::filesystem::path& ScavengeRootPath = ScavengedPaths[ScavengeOp.ScavengedContentIndex];
+
+ WriteScavengedSequenceToCache(ScavengeRootPath, ScavengedContent, ScavengeOp);
+
+ if (Context.WritePartsComplete.fetch_add(1) + 1 == Context.TotalPartWriteCount)
+ {
+ Context.FilteredWrittenBytesPerSecond.Stop();
+ }
+ }
+ });
+ }
+}
+
+void
+BuildsOperationUpdateFolder::ScheduleLooseChunkWrites(WriteChunksContext& Context, std::vector<LooseChunkHashWorkData>& LooseChunkHashWorks)
+{
+ for (uint32_t LooseChunkHashWorkIndex = 0; LooseChunkHashWorkIndex < LooseChunkHashWorks.size(); LooseChunkHashWorkIndex++)
+ {
+ if (m_AbortFlag)
+ {
+ break;
+ }
+
+ Context.Work.ScheduleWork(
+ m_IOWorkerPool,
+ [this, &Context, &LooseChunkHashWorks, LooseChunkHashWorkIndex](std::atomic<bool>&) {
+ ZEN_TRACE_CPU("Async_ReadPreDownloadedChunk");
+ if (!m_AbortFlag)
+ {
+ LooseChunkHashWorkData& LooseChunkHashWork = LooseChunkHashWorks[LooseChunkHashWorkIndex];
+ const uint32_t RemoteChunkIndex = LooseChunkHashWork.RemoteChunkIndex;
+ WriteLooseChunk(RemoteChunkIndex,
+ Context.ExistsResult,
+ Context.SequenceIndexChunksLeftToWriteCounters,
+ Context.WritePartsComplete,
+ std::move(LooseChunkHashWork.ChunkTargetPtrs),
+ Context.WriteCache,
+ Context.Work,
+ Context.TotalRequestCount,
+ Context.TotalPartWriteCount,
+ Context.FilteredDownloadedBytesPerSecond,
+ Context.FilteredWrittenBytesPerSecond);
+ }
+ },
+ WorkerThreadPool::EMode::EnableBacklog);
+ }
+}
+
+void
+BuildsOperationUpdateFolder::ScheduleLocalChunkCopies(WriteChunksContext& Context,
+ std::span<const CopyChunkData> CopyChunkDatas,
+ CloneQueryInterface* CloneQuery,
+ const std::vector<ChunkedFolderContent>& ScavengedContents,
+ const std::vector<ChunkedContentLookup>& ScavengedLookups,
+ const std::vector<std::filesystem::path>& ScavengedPaths)
+{
+ for (size_t CopyDataIndex = 0; CopyDataIndex < CopyChunkDatas.size(); CopyDataIndex++)
+ {
+ if (m_AbortFlag)
+ {
+ break;
+ }
+
+ Context.Work.ScheduleWork(
+ m_IOWorkerPool,
+ [this, &Context, CloneQuery, &CopyChunkDatas, &ScavengedContents, &ScavengedLookups, &ScavengedPaths, CopyDataIndex](
+ std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ ZEN_TRACE_CPU("Async_CopyLocal");
+
+ Context.FilteredWrittenBytesPerSecond.Start();
+ const CopyChunkData& CopyData = CopyChunkDatas[CopyDataIndex];
+
+ std::vector<uint32_t> WrittenSequenceIndexes = WriteLocalChunkToCache(CloneQuery,
+ CopyData,
+ ScavengedContents,
+ ScavengedLookups,
+ ScavengedPaths,
+ Context.WriteCache);
+ bool WritePartsDone = Context.WritePartsComplete.fetch_add(1) + 1 == Context.TotalPartWriteCount;
+ if (!m_AbortFlag)
+ {
+ if (WritePartsDone)
+ {
+ Context.FilteredWrittenBytesPerSecond.Stop();
+ }
+
+ // Write tracking, updating this must be done without any files open
+ std::vector<uint32_t> CompletedChunkSequences;
+ for (uint32_t RemoteSequenceIndex : WrittenSequenceIndexes)
+ {
+ if (CompleteSequenceChunk(RemoteSequenceIndex, Context.SequenceIndexChunksLeftToWriteCounters))
+ {
+ CompletedChunkSequences.push_back(RemoteSequenceIndex);
+ }
+ }
+ Context.WriteCache.Close(CompletedChunkSequences);
+ VerifyAndCompleteChunkSequencesAsync(CompletedChunkSequences, Context.Work);
+ }
+ }
+ });
+ }
+}
+
+void
+BuildsOperationUpdateFolder::ScheduleCachedBlockWrites(WriteChunksContext& Context, std::span<const uint32_t> CachedBlockIndexes)
+{
+ for (uint32_t BlockIndex : CachedBlockIndexes)
+ {
+ if (m_AbortFlag)
+ {
+ break;
+ }
+
+ Context.Work.ScheduleWork(m_IOWorkerPool, [this, &Context, BlockIndex](std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ ZEN_TRACE_CPU("Async_WriteCachedBlock");
+
+ const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
+ Context.FilteredWrittenBytesPerSecond.Start();
+
+ std::filesystem::path BlockChunkPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString();
+ IoBuffer BlockBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath);
+ if (!BlockBuffer)
+ {
+ throw std::runtime_error(fmt::format("Can not read block {} at {}", BlockDescription.BlockHash, BlockChunkPath));
+ }
+
+ if (!m_AbortFlag)
+ {
+ if (!WriteChunksBlockToCache(BlockDescription,
+ Context.SequenceIndexChunksLeftToWriteCounters,
+ Context.Work,
+ CompositeBuffer(std::move(BlockBuffer)),
+ Context.RemoteChunkIndexNeedsCopyFromSourceFlags,
+ Context.WriteCache))
+ {
+ std::error_code DummyEc;
+ RemoveFile(BlockChunkPath, DummyEc);
+ throw std::runtime_error(fmt::format("Block {} is malformed", BlockDescription.BlockHash));
+ }
+
+ std::error_code Ec = TryRemoveFile(BlockChunkPath);
+ if (Ec)
+ {
+ ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", BlockChunkPath, Ec.value(), Ec.message());
+ }
+
+ if (Context.WritePartsComplete.fetch_add(1) + 1 == Context.TotalPartWriteCount)
+ {
+ Context.FilteredWrittenBytesPerSecond.Stop();
+ }
+ }
+ }
+ });
+ }
+}
+
+void
+BuildsOperationUpdateFolder::SchedulePartialBlockDownloads(WriteChunksContext& Context,
+ const ChunkBlockAnalyser::BlockResult& PartialBlocks)
+{
+ for (size_t BlockRangeIndex = 0; BlockRangeIndex < PartialBlocks.BlockRanges.size();)
+ {
+ if (m_AbortFlag)
+ {
+ break;
+ }
+
+ size_t RangeCount = 1;
+ size_t RangesLeft = PartialBlocks.BlockRanges.size() - BlockRangeIndex;
+ const ChunkBlockAnalyser::BlockRangeDescriptor& CurrentBlockRange = PartialBlocks.BlockRanges[BlockRangeIndex];
+ while (RangeCount < RangesLeft &&
+ CurrentBlockRange.BlockIndex == PartialBlocks.BlockRanges[BlockRangeIndex + RangeCount].BlockIndex)
+ {
+ RangeCount++;
+ }
+
+ Context.Work.ScheduleWork(
+ m_NetworkPool,
+ [this, &Context, &PartialBlocks, BlockRangeStartIndex = BlockRangeIndex, RangeCount = RangeCount](std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ ZEN_TRACE_CPU("Async_GetPartialBlockRanges");
+
+ Context.FilteredDownloadedBytesPerSecond.Start();
+
+ DownloadPartialBlock(
+ PartialBlocks.BlockRanges,
+ BlockRangeStartIndex,
+ RangeCount,
+ Context.ExistsResult,
+ Context.TotalRequestCount,
+ Context.FilteredDownloadedBytesPerSecond,
+ [this, &Context, &PartialBlocks](IoBuffer&& InMemoryBuffer,
+ const std::filesystem::path& OnDiskPath,
+ size_t BlockRangeStartIndex,
+ std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths) {
+ if (!m_AbortFlag)
+ {
+ Context.Work.ScheduleWork(
+ m_IOWorkerPool,
+ [this,
+ &Context,
+ &PartialBlocks,
+ BlockRangeStartIndex,
+ BlockChunkPath = std::filesystem::path(OnDiskPath),
+ BlockPartialBuffer = std::move(InMemoryBuffer),
+ OffsetAndLengths =
+ std::vector<std::pair<uint64_t, uint64_t>>(OffsetAndLengths.begin(), OffsetAndLengths.end())](
+ std::atomic<bool>&) mutable {
+ if (!m_AbortFlag)
+ {
+ WritePartialBlockToCache(Context,
+ BlockRangeStartIndex,
+ std::move(BlockPartialBuffer),
+ BlockChunkPath,
+ OffsetAndLengths,
+ PartialBlocks);
+ }
+ },
+ OnDiskPath.empty() ? WorkerThreadPool::EMode::DisableBacklog : WorkerThreadPool::EMode::EnableBacklog);
+ }
+ });
+ }
+ });
+ BlockRangeIndex += RangeCount;
+ }
+}
+
+void
+BuildsOperationUpdateFolder::WritePartialBlockToCache(WriteChunksContext& Context,
+ size_t BlockRangeStartIndex,
+ IoBuffer BlockPartialBuffer,
+ const std::filesystem::path& BlockChunkPath,
+ std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths,
+ const ChunkBlockAnalyser::BlockResult& PartialBlocks)
+{
+ ZEN_TRACE_CPU("Async_WritePartialBlock");
+
+ const uint32_t BlockIndex = PartialBlocks.BlockRanges[BlockRangeStartIndex].BlockIndex;
+ const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
+
+ if (BlockChunkPath.empty())
+ {
+ ZEN_ASSERT(BlockPartialBuffer);
+ }
+ else
+ {
+ ZEN_ASSERT(!BlockPartialBuffer);
+ BlockPartialBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath);
+ if (!BlockPartialBuffer)
+ {
+ throw std::runtime_error(fmt::format("Could not open downloaded block {} from {}", BlockDescription.BlockHash, BlockChunkPath));
+ }
+ }
+
+ Context.FilteredWrittenBytesPerSecond.Start();
+
+ const size_t RangeCount = OffsetAndLengths.size();
+
+ for (size_t PartialRangeIndex = 0; PartialRangeIndex < RangeCount; PartialRangeIndex++)
+ {
+ const std::pair<uint64_t, uint64_t>& OffsetAndLength = OffsetAndLengths[PartialRangeIndex];
+ IoBuffer BlockRangeBuffer(BlockPartialBuffer, OffsetAndLength.first, OffsetAndLength.second);
+
+ const ChunkBlockAnalyser::BlockRangeDescriptor& RangeDescriptor =
+ PartialBlocks.BlockRanges[BlockRangeStartIndex + PartialRangeIndex];
+
+ if (!WritePartialBlockChunksToCache(BlockDescription,
+ Context.SequenceIndexChunksLeftToWriteCounters,
+ Context.Work,
+ CompositeBuffer(std::move(BlockRangeBuffer)),
+ RangeDescriptor.ChunkBlockIndexStart,
+ RangeDescriptor.ChunkBlockIndexStart + RangeDescriptor.ChunkBlockIndexCount - 1,
+ Context.RemoteChunkIndexNeedsCopyFromSourceFlags,
+ Context.WriteCache))
+ {
+ std::error_code DummyEc;
+ RemoveFile(BlockChunkPath, DummyEc);
+ throw std::runtime_error(fmt::format("Partial block {} is malformed", BlockDescription.BlockHash));
+ }
+
+ if (Context.WritePartsComplete.fetch_add(1) + 1 == Context.TotalPartWriteCount)
+ {
+ Context.FilteredWrittenBytesPerSecond.Stop();
+ }
+ }
+ std::error_code Ec = TryRemoveFile(BlockChunkPath);
+ if (Ec)
+ {
+ ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", BlockChunkPath, Ec.value(), Ec.message());
+ }
+}
+
+void
+BuildsOperationUpdateFolder::ScheduleFullBlockDownloads(WriteChunksContext& Context, std::span<const uint32_t> FullBlockIndexes)
+{
+ for (uint32_t BlockIndex : FullBlockIndexes)
+ {
+ if (m_AbortFlag)
+ {
+ break;
+ }
+
+ Context.Work.ScheduleWork(m_NetworkPool, [this, &Context, BlockIndex](std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ ZEN_TRACE_CPU("Async_GetFullBlock");
+
+ const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
+
+ Context.FilteredDownloadedBytesPerSecond.Start();
+
+ IoBuffer BlockBuffer;
+ const bool ExistsInCache =
+ m_Storage.CacheStorage && Context.ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash);
+ if (ExistsInCache)
+ {
+ BlockBuffer = m_Storage.CacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash);
+ }
+ if (!BlockBuffer)
+ {
+ try
+ {
+ BlockBuffer = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash);
+ }
+ catch (const std::exception&)
+ {
+ // Silence http errors due to abort
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
+ }
+ }
+ if (!m_AbortFlag)
+ {
+ if (!BlockBuffer)
+ {
+ throw std::runtime_error(fmt::format("Block {} is missing", BlockDescription.BlockHash));
+ }
+
+ uint64_t BlockSize = BlockBuffer.GetSize();
+ m_DownloadStats.DownloadedBlockCount++;
+ m_DownloadStats.DownloadedBlockByteCount += BlockSize;
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == Context.TotalRequestCount)
+ {
+ Context.FilteredDownloadedBytesPerSecond.Stop();
+ }
+
+ const bool PutInCache = !ExistsInCache && m_Storage.CacheStorage && m_Options.PopulateCache;
+
+ std::filesystem::path BlockChunkPath =
+ TryMoveDownloadedChunk(BlockBuffer,
+ m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(),
+ /* ForceDiskBased */ PutInCache || (BlockSize > m_Options.MaximumInMemoryPayloadSize));
+
+ if (PutInCache)
+ {
+ ZEN_ASSERT(!BlockChunkPath.empty());
+ IoBuffer CacheBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath);
+ if (CacheBuffer)
+ {
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ BlockDescription.BlockHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(SharedBuffer(CacheBuffer)));
+ }
+ }
+
+ if (!m_AbortFlag)
+ {
+ Context.Work.ScheduleWork(
+ m_IOWorkerPool,
+ [this, &Context, BlockIndex, BlockChunkPath, BlockBuffer = std::move(BlockBuffer)](std::atomic<bool>&) mutable {
+ if (!m_AbortFlag)
+ {
+ WriteFullBlockToCache(Context, BlockIndex, std::move(BlockBuffer), BlockChunkPath);
+ }
+ },
+ BlockChunkPath.empty() ? WorkerThreadPool::EMode::DisableBacklog : WorkerThreadPool::EMode::EnableBacklog);
+ }
+ }
+ }
+ });
+ }
+}
+
+void
+BuildsOperationUpdateFolder::WriteFullBlockToCache(WriteChunksContext& Context,
+ uint32_t BlockIndex,
+ IoBuffer BlockBuffer,
+ const std::filesystem::path& BlockChunkPath)
+{
+ ZEN_TRACE_CPU("Async_WriteFullBlock");
+
+ const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex];
+
+ if (BlockChunkPath.empty())
+ {
+ ZEN_ASSERT(BlockBuffer);
+ }
+ else
+ {
+ ZEN_ASSERT(!BlockBuffer);
+ BlockBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath);
+ if (!BlockBuffer)
+ {
+ throw std::runtime_error(fmt::format("Could not open dowloaded block {} from {}", BlockDescription.BlockHash, BlockChunkPath));
+ }
+ }
+
+ Context.FilteredWrittenBytesPerSecond.Start();
+ if (!WriteChunksBlockToCache(BlockDescription,
+ Context.SequenceIndexChunksLeftToWriteCounters,
+ Context.Work,
+ CompositeBuffer(std::move(BlockBuffer)),
+ Context.RemoteChunkIndexNeedsCopyFromSourceFlags,
+ Context.WriteCache))
+ {
+ std::error_code DummyEc;
+ RemoveFile(BlockChunkPath, DummyEc);
+ throw std::runtime_error(fmt::format("Block {} is malformed", BlockDescription.BlockHash));
+ }
+
+ if (!BlockChunkPath.empty())
+ {
+ std::error_code Ec = TryRemoveFile(BlockChunkPath);
+ if (Ec)
+ {
+ ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", BlockChunkPath, Ec.value(), Ec.message());
+ }
+ }
+
+ if (Context.WritePartsComplete.fetch_add(1) + 1 == Context.TotalPartWriteCount)
+ {
+ Context.FilteredWrittenBytesPerSecond.Stop();
+ }
+}
+
+void
+BuildsOperationUpdateFolder::ScheduleLocalFileRemovals(ParallelWork& Work,
+ std::span<const uint32_t> RemoveLocalPathIndexes,
+ std::atomic<uint64_t>& DeletedCount)
+{
+ for (uint32_t LocalPathIndex : RemoveLocalPathIndexes)
+ {
+ if (m_AbortFlag)
+ {
+ break;
+ }
+ Work.ScheduleWork(m_IOWorkerPool, [this, &DeletedCount, LocalPathIndex](std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ ZEN_TRACE_CPU("Async_RemoveFile");
+
+ const std::filesystem::path LocalFilePath = (m_Path / m_LocalContent.Paths[LocalPathIndex]).make_preferred();
+ SetFileReadOnlyWithRetry(LocalFilePath, false);
+ RemoveFileWithRetry(LocalFilePath);
+ DeletedCount++;
+ }
+ });
+ }
+}
+
+void
+BuildsOperationUpdateFolder::ScheduleTargetFinalization(
+ ParallelWork& Work,
+ std::span<const FinalizeTarget> Targets,
+ const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& SequenceHashToLocalPathIndex,
+ const tsl::robin_map<uint32_t, uint32_t>& RemotePathIndexToLocalPathIndex,
+ FolderContent& OutLocalFolderState,
+ std::atomic<uint64_t>& TargetsComplete)
+{
+ size_t TargetOffset = 0;
+ while (TargetOffset < Targets.size())
+ {
+ if (m_AbortFlag)
+ {
+ break;
+ }
+
+ size_t TargetCount = 1;
+ while ((TargetOffset + TargetCount) < Targets.size() &&
+ (Targets[TargetOffset + TargetCount].RawHash == Targets[TargetOffset].RawHash))
+ {
+ TargetCount++;
+ }
+
+ Work.ScheduleWork(m_IOWorkerPool,
+ [this,
+ &SequenceHashToLocalPathIndex,
+ Targets,
+ &RemotePathIndexToLocalPathIndex,
+ &OutLocalFolderState,
+ BaseTargetOffset = TargetOffset,
+ TargetCount,
+ &TargetsComplete](std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ FinalizeTargetGroup(BaseTargetOffset,
+ TargetCount,
+ Targets,
+ SequenceHashToLocalPathIndex,
+ RemotePathIndexToLocalPathIndex,
+ OutLocalFolderState,
+ TargetsComplete);
+ }
+ });
+
+ TargetOffset += TargetCount;
+ }
+}
+
+void
+BuildsOperationUpdateFolder::FinalizeTargetGroup(size_t BaseOffset,
+ size_t Count,
+ std::span<const FinalizeTarget> Targets,
+ const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& SequenceHashToLocalPathIndex,
+ const tsl::robin_map<uint32_t, uint32_t>& RemotePathIndexToLocalPathIndex,
+ FolderContent& OutLocalFolderState,
+ std::atomic<uint64_t>& TargetsComplete)
+{
+ ZEN_TRACE_CPU("Async_FinalizeChunkSequence");
+
+ size_t TargetOffset = BaseOffset;
+ const IoHash& RawHash = Targets[TargetOffset].RawHash;
+
+ if (RawHash == IoHash::Zero)
+ {
+ ZEN_TRACE_CPU("CreateEmptyFiles");
+ while (TargetOffset < (BaseOffset + Count))
+ {
+ const uint32_t RemotePathIndex = Targets[TargetOffset].RemotePathIndex;
+ ZEN_ASSERT(Targets[TargetOffset].RawHash == RawHash);
+ const std::filesystem::path& TargetPath = m_RemoteContent.Paths[RemotePathIndex];
+ std::filesystem::path TargetFilePath = (m_Path / TargetPath).make_preferred();
+ auto InPlaceIt = RemotePathIndexToLocalPathIndex.find(RemotePathIndex);
+ if (InPlaceIt == RemotePathIndexToLocalPathIndex.end() || InPlaceIt->second == 0)
+ {
+ if (IsFileWithRetry(TargetFilePath))
+ {
+ SetFileReadOnlyWithRetry(TargetFilePath, false);
+ }
+ else
+ {
+ CreateDirectories(TargetFilePath.parent_path());
+ }
+ BasicFile OutputFile;
+ OutputFile.Open(TargetFilePath, BasicFile::Mode::kTruncate);
+ }
+ OutLocalFolderState.Paths[RemotePathIndex] = TargetPath;
+ OutLocalFolderState.RawSizes[RemotePathIndex] = m_RemoteContent.RawSizes[RemotePathIndex];
+
+ OutLocalFolderState.Attributes[RemotePathIndex] =
+ m_RemoteContent.Attributes.empty()
+ ? GetNativeFileAttributes(TargetFilePath)
+ : SetNativeFileAttributes(TargetFilePath, m_RemoteContent.Platform, m_RemoteContent.Attributes[RemotePathIndex]);
+ OutLocalFolderState.ModificationTicks[RemotePathIndex] = GetModificationTickFromPath(TargetFilePath);
+
+ TargetOffset++;
+ TargetsComplete++;
+ }
+ }
+ else
+ {
+ ZEN_TRACE_CPU("FinalizeFile");
+ ZEN_ASSERT(m_RemoteLookup.RawHashToSequenceIndex.contains(RawHash));
+ const uint32_t FirstRemotePathIndex = Targets[TargetOffset].RemotePathIndex;
+ const std::filesystem::path& FirstTargetPath = m_RemoteContent.Paths[FirstRemotePathIndex];
+ std::filesystem::path FirstTargetFilePath = (m_Path / FirstTargetPath).make_preferred();
+
+ if (auto InPlaceIt = RemotePathIndexToLocalPathIndex.find(FirstRemotePathIndex); InPlaceIt != RemotePathIndexToLocalPathIndex.end())
+ {
+ ZEN_ASSERT_SLOW(IsFileWithRetry(FirstTargetFilePath));
+ }
+ else
+ {
+ if (IsFileWithRetry(FirstTargetFilePath))
+ {
+ SetFileReadOnlyWithRetry(FirstTargetFilePath, false);
+ }
+ else
+ {
+ CreateDirectories(FirstTargetFilePath.parent_path());
+ }
+
+ if (auto InplaceIt = SequenceHashToLocalPathIndex.find(RawHash); InplaceIt != SequenceHashToLocalPathIndex.end())
+ {
+ ZEN_TRACE_CPU("Copy");
+ const uint32_t LocalPathIndex = InplaceIt->second;
+ const std::filesystem::path& SourcePath = m_LocalContent.Paths[LocalPathIndex];
+ std::filesystem::path SourceFilePath = (m_Path / SourcePath).make_preferred();
+ ZEN_ASSERT_SLOW(IsFileWithRetry(SourceFilePath));
+
+ ZEN_DEBUG("Copying from '{}' -> '{}'", SourceFilePath, FirstTargetFilePath);
+ const uint64_t RawSize = m_LocalContent.RawSizes[LocalPathIndex];
+ FastCopyFile(m_Options.AllowFileClone,
+ m_Options.UseSparseFiles,
+ SourceFilePath,
+ FirstTargetFilePath,
+ RawSize,
+ m_DiskStats.WriteCount,
+ m_DiskStats.WriteByteCount,
+ m_DiskStats.CloneCount,
+ m_DiskStats.CloneByteCount);
+
+ m_RebuildFolderStateStats.FinalizeTreeFilesCopiedCount++;
+ }
+ else
+ {
+ ZEN_TRACE_CPU("Rename");
+ const std::filesystem::path CacheFilePath = GetFinalChunkedSequenceFileName(m_CacheFolderPath, RawHash);
+ ZEN_ASSERT_SLOW(IsFileWithRetry(CacheFilePath));
+
+ std::error_code Ec = RenameFileWithRetry(CacheFilePath, FirstTargetFilePath);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to move file from '{}' to '{}', reason: ({}) {}, retrying...",
+ CacheFilePath,
+ FirstTargetFilePath,
+ Ec.value(),
+ Ec.message());
+ Ec = RenameFileWithRetry(CacheFilePath, FirstTargetFilePath);
+ if (Ec)
+ {
+ throw std::system_error(std::error_code(Ec.value(), std::system_category()),
+ fmt::format("Failed to move file from '{}' to '{}', reason: ({}) {}",
+ CacheFilePath,
+ FirstTargetFilePath,
+ Ec.value(),
+ Ec.message()));
+ }
+ }
+
+ m_RebuildFolderStateStats.FinalizeTreeFilesMovedCount++;
+ }
+ }
+
+ OutLocalFolderState.Paths[FirstRemotePathIndex] = FirstTargetPath;
+ OutLocalFolderState.RawSizes[FirstRemotePathIndex] = m_RemoteContent.RawSizes[FirstRemotePathIndex];
+
+ OutLocalFolderState.Attributes[FirstRemotePathIndex] =
+ m_RemoteContent.Attributes.empty()
+ ? GetNativeFileAttributes(FirstTargetFilePath)
+ : SetNativeFileAttributes(FirstTargetFilePath, m_RemoteContent.Platform, m_RemoteContent.Attributes[FirstRemotePathIndex]);
+ OutLocalFolderState.ModificationTicks[FirstRemotePathIndex] = GetModificationTickFromPath(FirstTargetFilePath);
+
+ TargetOffset++;
+ TargetsComplete++;
+
+ while (TargetOffset < (BaseOffset + Count))
+ {
+ const uint32_t RemotePathIndex = Targets[TargetOffset].RemotePathIndex;
+ ZEN_ASSERT(Targets[TargetOffset].RawHash == RawHash);
+ const std::filesystem::path& TargetPath = m_RemoteContent.Paths[RemotePathIndex];
+ std::filesystem::path TargetFilePath = (m_Path / TargetPath).make_preferred();
+
+ if (auto InPlaceIt = RemotePathIndexToLocalPathIndex.find(RemotePathIndex); InPlaceIt != RemotePathIndexToLocalPathIndex.end())
+ {
+ ZEN_ASSERT_SLOW(IsFileWithRetry(TargetFilePath));
+ }
+ else
+ {
+ ZEN_TRACE_CPU("Copy");
+ if (IsFileWithRetry(TargetFilePath))
+ {
+ SetFileReadOnlyWithRetry(TargetFilePath, false);
+ }
+ else
+ {
+ CreateDirectories(TargetFilePath.parent_path());
+ }
+
+ ZEN_ASSERT_SLOW(IsFileWithRetry(FirstTargetFilePath));
+ ZEN_DEBUG("Copying from '{}' -> '{}'", FirstTargetFilePath, TargetFilePath);
+ const uint64_t RawSize = m_RemoteContent.RawSizes[RemotePathIndex];
+ FastCopyFile(m_Options.AllowFileClone,
+ m_Options.UseSparseFiles,
+ FirstTargetFilePath,
+ TargetFilePath,
+ RawSize,
+ m_DiskStats.WriteCount,
+ m_DiskStats.WriteByteCount,
+ m_DiskStats.CloneCount,
+ m_DiskStats.CloneByteCount);
+
+ m_RebuildFolderStateStats.FinalizeTreeFilesCopiedCount++;
+ }
+
+ OutLocalFolderState.Paths[RemotePathIndex] = TargetPath;
+ OutLocalFolderState.RawSizes[RemotePathIndex] = m_RemoteContent.RawSizes[RemotePathIndex];
+
+ OutLocalFolderState.Attributes[RemotePathIndex] =
+ m_RemoteContent.Attributes.empty()
+ ? GetNativeFileAttributes(TargetFilePath)
+ : SetNativeFileAttributes(TargetFilePath, m_RemoteContent.Platform, m_RemoteContent.Attributes[RemotePathIndex]);
+ OutLocalFolderState.ModificationTicks[RemotePathIndex] = GetModificationTickFromPath(TargetFilePath);
+
+ TargetOffset++;
+ TargetsComplete++;
+ }
+ }
+}
+
std::vector<BuildsOperationUpdateFolder::ScavengeSource>
BuildsOperationUpdateFolder::FindScavengeSources()
{
@@ -2526,7 +2624,7 @@ BuildsOperationUpdateFolder::FindScavengeSources()
}
catch (const std::exception& Ex)
{
- ZEN_OPERATION_LOG_WARN(m_LogOutput, "{}", Ex.what());
+ ZEN_WARN("{}", Ex.what());
DeleteEntry = true;
}
@@ -2563,11 +2661,7 @@ BuildsOperationUpdateFolder::ScanTargetFolder(const tsl::robin_map<IoHash, uint3
ZEN_ASSERT_SLOW(IsFile(CacheFilePath));
if (m_Options.IsVerbose)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Found sequence {} at {} ({})",
- RemoteSequenceRawHash,
- CacheFilePath,
- NiceBytes(RemoteRawSize));
+ ZEN_INFO("Found sequence {} at {} ({})", RemoteSequenceRawHash, CacheFilePath, NiceBytes(RemoteRawSize));
}
}
else if (auto CacheChunkIt = CachedChunkHashesFound.find(RemoteSequenceRawHash); CacheChunkIt != CachedChunkHashesFound.end())
@@ -2576,11 +2670,7 @@ BuildsOperationUpdateFolder::ScanTargetFolder(const tsl::robin_map<IoHash, uint3
ZEN_ASSERT_SLOW(IsFile(CacheFilePath));
if (m_Options.IsVerbose)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Found chunk {} at {} ({})",
- RemoteSequenceRawHash,
- CacheFilePath,
- NiceBytes(RemoteRawSize));
+ ZEN_INFO("Found chunk {} at {} ({})", RemoteSequenceRawHash, CacheFilePath, NiceBytes(RemoteRawSize));
}
}
else if (auto It = m_LocalLookup.RawHashToSequenceIndex.find(RemoteSequenceRawHash);
@@ -2594,11 +2684,7 @@ BuildsOperationUpdateFolder::ScanTargetFolder(const tsl::robin_map<IoHash, uint3
m_CacheMappingStats.LocalPathsMatchingSequencesByteCount += RemoteRawSize;
if (m_Options.IsVerbose)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Found sequence {} at {} ({})",
- RemoteSequenceRawHash,
- LocalFilePath,
- NiceBytes(RemoteRawSize));
+ ZEN_INFO("Found sequence {} at {} ({})", RemoteSequenceRawHash, LocalFilePath, NiceBytes(RemoteRawSize));
}
}
else
@@ -2624,10 +2710,9 @@ BuildsOperationUpdateFolder::FindScavengeContent(const ScavengeSource& Source,
BuildSaveState SavedState = ReadBuildSaveStateFile(Source.StateFilePath);
if (SavedState.Version == BuildSaveState::NoVersion)
{
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput,
- "Skipping old build state at '{}', state files before version {} can not be trusted during scavenge",
- Source.StateFilePath,
- BuildSaveState::kVersion1);
+ ZEN_DEBUG("Skipping old build state at '{}', state files before version {} can not be trusted during scavenge",
+ Source.StateFilePath,
+ BuildSaveState::kVersion1);
return false;
}
OutScavengedLocalContent = std::move(SavedState.State.ChunkedContent);
@@ -2635,7 +2720,7 @@ BuildsOperationUpdateFolder::FindScavengeContent(const ScavengeSource& Source,
}
catch (const std::exception& Ex)
{
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput, "Skipping invalid build state at '{}', reason: {}", Source.StateFilePath, Ex.what());
+ ZEN_DEBUG("Skipping invalid build state at '{}', reason: {}", Source.StateFilePath, Ex.what());
return false;
}
@@ -2688,11 +2773,10 @@ BuildsOperationUpdateFolder::FindScavengeContent(const ScavengeSource& Source,
}
else
{
- ZEN_OPERATION_LOG_WARN(m_LogOutput,
- "Scavenged state file at '{}' for '{}' is invalid, skipping scavenging for sequence {}",
- Source.StateFilePath,
- Source.Path,
- SequenceHash);
+ ZEN_WARN("Scavenged state file at '{}' for '{}' is invalid, skipping scavenging for sequence {}",
+ Source.StateFilePath,
+ Source.Path,
+ SequenceHash);
}
}
}
@@ -2928,7 +3012,10 @@ BuildsOperationUpdateFolder::CheckRequiredDiskSpace(const tsl::robin_map<std::st
if (Space.Free < (RequiredSpace + 16u * 1024u * 1024u))
{
throw std::runtime_error(
- fmt::format("Not enough free space for target path '{}', {} of free space is needed", m_Path, RequiredSpace));
+ fmt::format("Not enough free space for target path '{}', {} of free space is needed but only {} is available",
+ m_Path,
+ NiceBytes(RequiredSpace),
+ NiceBytes(Space.Free)));
}
}
@@ -2980,18 +3067,13 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd
FilteredRate& FilteredDownloadedBytesPerSecond,
FilteredRate& FilteredWrittenBytesPerSecond)
{
- std::filesystem::path ExistingCompressedChunkPath;
- if (!m_Options.PrimeCacheOnly)
+ const IoHash& ChunkHash = m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex];
+ std::filesystem::path ExistingCompressedChunkPath = FindDownloadedChunk(ChunkHash);
+ if (!ExistingCompressedChunkPath.empty())
{
- const IoHash& ChunkHash = m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex];
- ExistingCompressedChunkPath = FindDownloadedChunk(ChunkHash);
- if (!ExistingCompressedChunkPath.empty())
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount)
{
- m_DownloadStats.RequestsCompleteCount++;
- if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount)
- {
- FilteredDownloadedBytesPerSecond.Stop();
- }
+ FilteredDownloadedBytesPerSecond.Stop();
}
}
if (!m_AbortFlag)
@@ -3009,7 +3091,7 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd
&FilteredWrittenBytesPerSecond,
RemoteChunkIndex,
ChunkTargetPtrs = std::move(ChunkTargetPtrs),
- CompressedChunkPath = std::move(ExistingCompressedChunkPath)](std::atomic<bool>& AbortFlag) mutable {
+ CompressedChunkPath = std::move(ExistingCompressedChunkPath)](std::atomic<bool>& AbortFlag) {
if (!AbortFlag)
{
ZEN_TRACE_CPU("Async_WritePreDownloadedChunk");
@@ -3027,11 +3109,11 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd
bool NeedHashVerify =
WriteCompressedChunkToCache(ChunkHash, ChunkTargetPtrs, WriteCache, std::move(CompressedPart));
- WritePartsComplete++;
+ bool WritePartsDone = WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount;
if (!AbortFlag)
{
- if (WritePartsComplete == TotalPartWriteCount)
+ if (WritePartsDone)
{
FilteredWrittenBytesPerSecond.Stop();
}
@@ -3039,11 +3121,7 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd
std::error_code Ec = TryRemoveFile(CompressedChunkPath);
if (Ec)
{
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput,
- "Failed removing file '{}', reason: ({}) {}",
- CompressedChunkPath,
- Ec.value(),
- Ec.message());
+ ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", CompressedChunkPath, Ec.value(), Ec.message());
}
std::vector<uint32_t> CompletedSequences =
@@ -3085,6 +3163,8 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd
DownloadBuildBlob(RemoteChunkIndex,
ExistsResult,
Work,
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
[this,
&ExistsResult,
SequenceIndexChunksLeftToWriteCounters,
@@ -3092,19 +3172,11 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd
&Work,
&WritePartsComplete,
TotalPartWriteCount,
- TotalRequestCount,
RemoteChunkIndex,
- &FilteredDownloadedBytesPerSecond,
&FilteredWrittenBytesPerSecond,
ChunkTargetPtrs = std::move(ChunkTargetPtrs)](IoBuffer&& Payload) mutable {
- if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount)
- {
- FilteredDownloadedBytesPerSecond.Stop();
- }
- IoBufferFileReference FileRef;
- bool EnableBacklog = Payload.GetFileReference(FileRef);
- AsyncWriteDownloadedChunk(m_Options.ZenFolderPath,
- RemoteChunkIndex,
+ AsyncWriteDownloadedChunk(RemoteChunkIndex,
+ ExistsResult,
std::move(ChunkTargetPtrs),
WriteCache,
Work,
@@ -3112,8 +3184,7 @@ BuildsOperationUpdateFolder::WriteLooseChunk(const uint32_t RemoteChunkInd
SequenceIndexChunksLeftToWriteCounters,
WritePartsComplete,
TotalPartWriteCount,
- FilteredWrittenBytesPerSecond,
- EnableBacklog);
+ FilteredWrittenBytesPerSecond);
});
}
});
@@ -3125,6 +3196,8 @@ void
BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkIndex,
const BlobsExistsResult& ExistsResult,
ParallelWork& Work,
+ uint64_t TotalRequestCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond,
std::function<void(IoBuffer&& Payload)>&& OnDownloaded)
{
const IoHash& ChunkHash = m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex];
@@ -3140,7 +3213,10 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde
uint64_t BlobSize = BuildBlob.GetSize();
m_DownloadStats.DownloadedChunkCount++;
m_DownloadStats.DownloadedChunkByteCount += BlobSize;
- m_DownloadStats.RequestsCompleteCount++;
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount)
+ {
+ FilteredDownloadedBytesPerSecond.Stop();
+ }
OnDownloaded(std::move(BuildBlob));
}
else
@@ -3157,16 +3233,11 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde
m_NetworkPool,
m_DownloadStats.DownloadedChunkByteCount,
m_DownloadStats.MultipartAttachmentCount,
- [this, &Work, ChunkHash, RemoteChunkIndex, OnDownloaded = std::move(OnDownloaded)](IoBuffer&& Payload) mutable {
+ [this, &FilteredDownloadedBytesPerSecond, TotalRequestCount, OnDownloaded = std::move(OnDownloaded)](IoBuffer&& Payload) {
m_DownloadStats.DownloadedChunkCount++;
- m_DownloadStats.RequestsCompleteCount++;
-
- if (Payload && m_Storage.CacheStorage && m_Options.PopulateCache)
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount)
{
- m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
- ChunkHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(SharedBuffer(Payload)));
+ FilteredDownloadedBytesPerSecond.Stop();
}
OnDownloaded(std::move(Payload));
@@ -3174,26 +3245,34 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde
}
else
{
- BuildBlob = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, ChunkHash);
- if (BuildBlob && m_Storage.CacheStorage && m_Options.PopulateCache)
+ try
{
- m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
- ChunkHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(SharedBuffer(BuildBlob)));
+ BuildBlob = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, ChunkHash);
}
- if (!BuildBlob)
+ catch (const std::exception&)
{
- throw std::runtime_error(fmt::format("Chunk {} is missing", ChunkHash));
+ // Silence http errors due to abort
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
}
- if (!m_Options.PrimeCacheOnly)
+ if (!m_AbortFlag)
{
+ if (!BuildBlob)
+ {
+ throw std::runtime_error(fmt::format("Chunk {} is missing", ChunkHash));
+ }
+
if (!m_AbortFlag)
{
uint64_t BlobSize = BuildBlob.GetSize();
m_DownloadStats.DownloadedChunkCount++;
m_DownloadStats.DownloadedChunkByteCount += BlobSize;
- m_DownloadStats.RequestsCompleteCount++;
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount)
+ {
+ FilteredDownloadedBytesPerSecond.Stop();
+ }
OnDownloaded(std::move(BuildBlob));
}
@@ -3208,6 +3287,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
size_t BlockRangeStartIndex,
size_t BlockRangeCount,
const BlobsExistsResult& ExistsResult,
+ uint64_t TotalRequestCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond,
std::function<void(IoBuffer&& InMemoryBuffer,
const std::filesystem::path& OnDiskPath,
size_t BlockRangeStartIndex,
@@ -3222,6 +3303,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
IoBuffer&& BlockRangeBuffer,
size_t BlockRangeStartIndex,
std::span<const std::pair<uint64_t, uint64_t>> BlockOffsetAndLengths,
+ uint64_t TotalRequestCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond,
const std::function<void(IoBuffer && InMemoryBuffer,
const std::filesystem::path& OnDiskPath,
size_t BlockRangeStartIndex,
@@ -3229,63 +3312,23 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
uint64_t BlockRangeBufferSize = BlockRangeBuffer.GetSize();
m_DownloadStats.DownloadedBlockCount++;
m_DownloadStats.DownloadedBlockByteCount += BlockRangeBufferSize;
- m_DownloadStats.RequestsCompleteCount += BlockOffsetAndLengths.size();
-
- std::filesystem::path BlockChunkPath;
-
- // Check if the dowloaded block is file based and we can move it directly without rewriting it
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(BlockOffsetAndLengths.size()) + BlockOffsetAndLengths.size() ==
+ TotalRequestCount)
{
- IoBufferFileReference FileRef;
- if (BlockRangeBuffer.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) &&
- (FileRef.FileChunkSize == BlockRangeBufferSize))
- {
- ZEN_TRACE_CPU("MoveTempPartialBlock");
-
- std::error_code Ec;
- std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, Ec);
- if (!Ec)
- {
- BlockRangeBuffer.SetDeleteOnClose(false);
- BlockRangeBuffer = {};
-
- IoHashStream RangeId;
- for (const std::pair<uint64_t, uint64_t>& Range : BlockOffsetAndLengths)
- {
- RangeId.Append(&Range.first, sizeof(uint64_t));
- RangeId.Append(&Range.second, sizeof(uint64_t));
- }
-
- BlockChunkPath = m_TempBlockFolderPath / fmt::format("{}_{}", BlockDescription.BlockHash, RangeId.GetHash());
- RenameFile(TempBlobPath, BlockChunkPath, Ec);
- if (Ec)
- {
- BlockChunkPath = std::filesystem::path{};
-
- // Re-open the temp file again
- BasicFile OpenTemp(TempBlobPath, BasicFile::Mode::kDelete);
- BlockRangeBuffer = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BlockRangeBufferSize, true);
- BlockRangeBuffer.SetDeleteOnClose(true);
- }
- }
- }
+ FilteredDownloadedBytesPerSecond.Stop();
}
- if (BlockChunkPath.empty() && (BlockRangeBufferSize > m_Options.MaximumInMemoryPayloadSize))
+ IoHashStream RangeId;
+ for (const std::pair<uint64_t, uint64_t>& Range : BlockOffsetAndLengths)
{
- ZEN_TRACE_CPU("WriteTempPartialBlock");
-
- IoHashStream RangeId;
- for (const std::pair<uint64_t, uint64_t>& Range : BlockOffsetAndLengths)
- {
- RangeId.Append(&Range.first, sizeof(uint64_t));
- RangeId.Append(&Range.second, sizeof(uint64_t));
- }
-
- // Could not be moved and rather large, lets store it on disk
- BlockChunkPath = m_TempBlockFolderPath / fmt::format("{}_{}", BlockDescription.BlockHash, RangeId.GetHash());
- TemporaryFile::SafeWriteFile(BlockChunkPath, BlockRangeBuffer);
- BlockRangeBuffer = {};
+ RangeId.Append(&Range.first, sizeof(uint64_t));
+ RangeId.Append(&Range.second, sizeof(uint64_t));
}
+ std::filesystem::path BlockChunkPath =
+ TryMoveDownloadedChunk(BlockRangeBuffer,
+ m_TempBlockFolderPath / fmt::format("{}_{}", BlockDescription.BlockHash, RangeId.GetHash()),
+ /* ForceDiskBased */ BlockRangeBufferSize > m_Options.MaximumInMemoryPayloadSize);
+
if (!m_AbortFlag)
{
OnDownloaded(std::move(BlockRangeBuffer), std::move(BlockChunkPath), BlockRangeStartIndex, BlockOffsetAndLengths);
@@ -3337,6 +3380,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
std::move(PayloadBuffer),
SubRangeStartIndex,
std::vector<std::pair<uint64_t, uint64_t>>{std::make_pair(0u, SubRange.second)},
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
OnDownloaded);
SubRangeCountComplete += SubRangeCount;
continue;
@@ -3361,6 +3406,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
std::move(RangeBuffers.PayloadBuffer),
SubRangeStartIndex,
RangesSpan.subspan(SubRangeCountComplete, SubRangeCount),
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
OnDownloaded);
SubRangeCountComplete += SubRangeCount;
continue;
@@ -3371,6 +3418,8 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
std::move(RangeBuffers.PayloadBuffer),
SubRangeStartIndex,
RangeBuffers.Ranges,
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
OnDownloaded);
SubRangeCountComplete += SubRangeCount;
continue;
@@ -3383,60 +3432,97 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount);
- BuildStorageBase::BuildBlobRanges RangeBuffers =
- m_Storage.BuildStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges);
- if (m_AbortFlag)
+ BuildStorageBase::BuildBlobRanges RangeBuffers;
+
+ try
{
- break;
+ RangeBuffers = m_Storage.BuildStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges);
}
- if (RangeBuffers.PayloadBuffer)
+ catch (const std::exception&)
{
- if (RangeBuffers.Ranges.empty())
+ // Silence http errors due to abort
+ if (!m_AbortFlag)
{
- // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3
- // Upload to cache (if enabled) and use the whole payload for the remaining ranges
+ throw;
+ }
+ }
- if (m_Storage.CacheStorage && m_Options.PopulateCache)
+ if (!m_AbortFlag)
+ {
+ if (RangeBuffers.PayloadBuffer)
+ {
+ if (RangeBuffers.Ranges.empty())
{
- m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
- BlockDescription.BlockHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(std::vector<IoBuffer>{RangeBuffers.PayloadBuffer}));
+ // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3
+ // Upload to cache (if enabled) and use the whole payload for the remaining ranges
+
+ const uint64_t Size = RangeBuffers.PayloadBuffer.GetSize();
+
+ const bool PopulateCache = !ExistsInCache && m_Storage.CacheStorage && m_Options.PopulateCache;
+
+ std::filesystem::path BlockPath =
+ TryMoveDownloadedChunk(RangeBuffers.PayloadBuffer,
+ m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(),
+ /* ForceDiskBased */ PopulateCache || Size > m_Options.MaximumInMemoryPayloadSize);
+ if (!BlockPath.empty())
+ {
+ RangeBuffers.PayloadBuffer = IoBufferBuilder::MakeFromFile(BlockPath);
+ if (!RangeBuffers.PayloadBuffer)
+ {
+ throw std::runtime_error(
+ fmt::format("Failed to read block {} from temporary path '{}'", BlockDescription.BlockHash, BlockPath));
+ }
+ RangeBuffers.PayloadBuffer.SetDeleteOnClose(true);
+ }
+
+ if (PopulateCache)
+ {
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ BlockDescription.BlockHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(SharedBuffer(RangeBuffers.PayloadBuffer)));
+ }
+
if (m_AbortFlag)
{
break;
}
- }
- SubRangeCount = Ranges.size() - SubRangeCountComplete;
- ProcessDownload(BlockDescription,
- std::move(RangeBuffers.PayloadBuffer),
- SubRangeStartIndex,
- RangesSpan.subspan(SubRangeCountComplete, SubRangeCount),
- OnDownloaded);
+ SubRangeCount = Ranges.size() - SubRangeCountComplete;
+ ProcessDownload(BlockDescription,
+ std::move(RangeBuffers.PayloadBuffer),
+ SubRangeStartIndex,
+ RangesSpan.subspan(SubRangeCountComplete, SubRangeCount),
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
+ OnDownloaded);
+ }
+ else
+ {
+ if (RangeBuffers.Ranges.size() != SubRanges.size())
+ {
+ throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges",
+ SubRanges.size(),
+ BlockDescription.BlockHash,
+ RangeBuffers.Ranges.size()));
+ }
+ ProcessDownload(BlockDescription,
+ std::move(RangeBuffers.PayloadBuffer),
+ SubRangeStartIndex,
+ RangeBuffers.Ranges,
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
+ OnDownloaded);
+ }
}
else
{
- if (RangeBuffers.Ranges.size() != SubRanges.size())
- {
- throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges",
- SubRanges.size(),
- BlockDescription.BlockHash,
- RangeBuffers.Ranges.size()));
- }
- ProcessDownload(BlockDescription,
- std::move(RangeBuffers.PayloadBuffer),
- SubRangeStartIndex,
- RangeBuffers.Ranges,
- OnDownloaded);
+ throw std::runtime_error(
+ fmt::format("Block {} is missing when fetching {} ranges", BlockDescription.BlockHash, SubRangeCount));
}
- }
- else
- {
- throw std::runtime_error(fmt::format("Block {} is missing when fetching {} ranges", BlockDescription.BlockHash, SubRangeCount));
- }
- SubRangeCountComplete += SubRangeCount;
+ SubRangeCountComplete += SubRangeCount;
+ }
}
}
@@ -3654,7 +3740,7 @@ BuildsOperationUpdateFolder::WriteLocalChunkToCache(CloneQueryInterface* C
if (m_Options.IsVerbose)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Copied {} from {}", NiceBytes(CacheLocalFileBytesRead), SourceFilePath);
+ ZEN_INFO("Copied {} from {}", NiceBytes(CacheLocalFileBytesRead), SourceFilePath);
}
std::vector<uint32_t> Result;
@@ -4149,8 +4235,8 @@ BuildsOperationUpdateFolder::WritePartialBlockChunksToCache(const ChunkBlockDesc
}
void
-BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(const std::filesystem::path& ZenFolderPath,
- uint32_t RemoteChunkIndex,
+BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(uint32_t RemoteChunkIndex,
+ const BlobsExistsResult& ExistsResult,
std::vector<const ChunkedContentLookup::ChunkSequenceLocation*>&& ChunkTargetPtrs,
BufferedWriteFileCache& WriteCache,
ParallelWork& Work,
@@ -4158,8 +4244,7 @@ BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(const std::filesystem::pa
std::span<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters,
std::atomic<uint64_t>& WritePartsComplete,
const uint64_t TotalPartWriteCount,
- FilteredRate& FilteredWrittenBytesPerSecond,
- bool EnableBacklog)
+ FilteredRate& FilteredWrittenBytesPerSecond)
{
ZEN_TRACE_CPU("AsyncWriteDownloadedChunk");
@@ -4167,48 +4252,32 @@ BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(const std::filesystem::pa
const uint64_t Size = Payload.GetSize();
- std::filesystem::path CompressedChunkPath;
+ const bool ExistsInCache = m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(ChunkHash);
+
+ const bool PopulateCache = !ExistsInCache && m_Storage.CacheStorage && m_Options.PopulateCache;
- // Check if the dowloaded chunk is file based and we can move it directly without rewriting it
+ std::filesystem::path CompressedChunkPath =
+ TryMoveDownloadedChunk(Payload,
+ m_TempDownloadFolderPath / ChunkHash.ToHexString(),
+ /* ForceDiskBased */ PopulateCache || Size > m_Options.MaximumInMemoryPayloadSize);
+ if (PopulateCache)
{
- IoBufferFileReference FileRef;
- if (Payload.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) && (FileRef.FileChunkSize == Size))
+ IoBuffer CacheBlob = IoBufferBuilder::MakeFromFile(CompressedChunkPath);
+ if (CacheBlob)
{
- ZEN_TRACE_CPU("MoveTempChunk");
- std::error_code Ec;
- std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, Ec);
- if (!Ec)
- {
- Payload.SetDeleteOnClose(false);
- Payload = {};
- CompressedChunkPath = m_TempDownloadFolderPath / ChunkHash.ToHexString();
- RenameFile(TempBlobPath, CompressedChunkPath, Ec);
- if (Ec)
- {
- CompressedChunkPath = std::filesystem::path{};
-
- // Re-open the temp file again
- BasicFile OpenTemp(TempBlobPath, BasicFile::Mode::kDelete);
- Payload = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, Size, true);
- Payload.SetDeleteOnClose(true);
- }
- }
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ ChunkHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(SharedBuffer(CacheBlob)));
}
}
- if (CompressedChunkPath.empty() && (Size > m_Options.MaximumInMemoryPayloadSize))
- {
- ZEN_TRACE_CPU("WriteTempChunk");
- // Could not be moved and rather large, lets store it on disk
- CompressedChunkPath = m_TempDownloadFolderPath / ChunkHash.ToHexString();
- TemporaryFile::SafeWriteFile(CompressedChunkPath, Payload);
- Payload = {};
- }
+ IoBufferFileReference FileRef;
+ bool EnableBacklog = !CompressedChunkPath.empty() || Payload.GetFileReference(FileRef);
Work.ScheduleWork(
m_IOWorkerPool,
- [&ZenFolderPath,
- this,
+ [this,
SequenceIndexChunksLeftToWriteCounters,
&Work,
CompressedChunkPath,
@@ -4244,8 +4313,7 @@ BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(const std::filesystem::pa
bool NeedHashVerify = WriteCompressedChunkToCache(ChunkHash, ChunkTargetPtrs, WriteCache, std::move(CompressedPart));
if (!m_AbortFlag)
{
- WritePartsComplete++;
- if (WritePartsComplete == TotalPartWriteCount)
+ if (WritePartsComplete.fetch_add(1) + 1 == TotalPartWriteCount)
{
FilteredWrittenBytesPerSecond.Stop();
}
@@ -4255,11 +4323,7 @@ BuildsOperationUpdateFolder::AsyncWriteDownloadedChunk(const std::filesystem::pa
std::error_code Ec = TryRemoveFile(CompressedChunkPath);
if (Ec)
{
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput,
- "Failed removing file '{}', reason: ({}) {}",
- CompressedChunkPath,
- Ec.value(),
- Ec.message());
+ ZEN_DEBUG("Failed removing file '{}', reason: ({}) {}", CompressedChunkPath, Ec.value(), Ec.message());
}
}
@@ -4412,7 +4476,8 @@ BuildsOperationUpdateFolder::VerifySequence(uint32_t RemoteSequenceIndex)
////////////////////// BuildsOperationUploadFolder
-BuildsOperationUploadFolder::BuildsOperationUploadFolder(OperationLogOutput& OperationLogOutput,
+BuildsOperationUploadFolder::BuildsOperationUploadFolder(LoggerRef Log,
+ ProgressBase& Progress,
StorageInstance& Storage,
std::atomic<bool>& AbortFlag,
std::atomic<bool>& PauseFlag,
@@ -4423,7 +4488,8 @@ BuildsOperationUploadFolder::BuildsOperationUploadFolder(OperationLogOutput&
bool CreateBuild,
const CbObject& MetaData,
const Options& Options)
-: m_LogOutput(OperationLogOutput)
+: m_Log(Log)
+, m_Progress(Progress)
, m_Storage(Storage)
, m_AbortFlag(AbortFlag)
, m_PauseFlag(PauseFlag)
@@ -4476,9 +4542,7 @@ BuildsOperationUploadFolder::PrepareBuild()
}
else if (m_Options.AllowMultiparts)
{
- ZEN_OPERATION_LOG_WARN(m_LogOutput,
- "PreferredMultipartChunkSize is unknown. Defaulting to '{}'",
- NiceBytes(Result.PreferredMultipartChunkSize));
+ ZEN_WARN("PreferredMultipartChunkSize is unknown. Defaulting to '{}'", NiceBytes(Result.PreferredMultipartChunkSize));
}
}
@@ -4538,10 +4602,8 @@ BuildsOperationUploadFolder::ReadFolder()
return true;
},
m_IOWorkerPool,
- m_LogOutput.GetProgressUpdateDelayMS(),
- [&](bool, std::ptrdiff_t) {
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Found {} files in '{}'...", LocalFolderScanStats.AcceptedFileCount.load(), m_Path);
- },
+ m_Progress.GetProgressUpdateDelayMS(),
+ [&](bool, std::ptrdiff_t) { ZEN_INFO("Found {} files in '{}'...", LocalFolderScanStats.AcceptedFileCount.load(), m_Path); },
m_AbortFlag);
Part.TotalRawSize = std::accumulate(Part.Content.RawSizes.begin(), Part.Content.RawSizes.end(), std::uint64_t(0));
@@ -4655,10 +4717,7 @@ BuildsOperationUploadFolder::Execute(const Oid& BuildPartId,
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Reading {} parts took {}",
- UploadParts.size(),
- NiceTimeSpanMs(ReadPartsTimer.GetElapsedTimeMs()));
+ ZEN_INFO("Reading {} parts took {}", UploadParts.size(), NiceTimeSpanMs(ReadPartsTimer.GetElapsedTimeMs()));
}
const uint32_t PartsUploadStepCount = gsl::narrow<uint32_t>(uint32_t(PartTaskSteps::StepCount) * UploadParts.size());
@@ -4669,7 +4728,7 @@ BuildsOperationUploadFolder::Execute(const Oid& BuildPartId,
const uint32_t CleanupStep = FinalizeBuildStep + 1;
const uint32_t StepCount = CleanupStep + 1;
- auto EndProgress = MakeGuard([&]() { m_LogOutput.SetLogOperationProgress(StepCount, StepCount); });
+ auto EndProgress = MakeGuard([&]() { m_Progress.SetLogOperationProgress(StepCount, StepCount); });
Stopwatch ProcessTimer;
@@ -4677,7 +4736,7 @@ BuildsOperationUploadFolder::Execute(const Oid& BuildPartId,
CreateDirectories(m_Options.TempDir);
auto _ = MakeGuard([&]() { CleanAndRemoveDirectory(m_IOWorkerPool, m_AbortFlag, m_PauseFlag, m_Options.TempDir); });
- m_LogOutput.SetLogOperationProgress(PrepareBuildStep, StepCount);
+ m_Progress.SetLogOperationProgress(PrepareBuildStep, StepCount);
m_PrepBuildResultFuture = m_NetworkPool.EnqueueTask(std::packaged_task<PrepareBuildResult()>{[this] { return PrepareBuild(); }},
WorkerThreadPool::EMode::EnableBacklog);
@@ -4694,7 +4753,7 @@ BuildsOperationUploadFolder::Execute(const Oid& BuildPartId,
}
}
- m_LogOutput.SetLogOperationProgress(FinalizeBuildStep, StepCount);
+ m_Progress.SetLogOperationProgress(FinalizeBuildStep, StepCount);
if (m_CreateBuild && !m_AbortFlag)
{
@@ -4702,11 +4761,11 @@ BuildsOperationUploadFolder::Execute(const Oid& BuildPartId,
m_Storage.BuildStorage->FinalizeBuild(m_BuildId);
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "FinalizeBuild took {}", NiceTimeSpanMs(FinalizeBuildTimer.GetElapsedTimeMs()));
+ ZEN_INFO("FinalizeBuild took {}", NiceTimeSpanMs(FinalizeBuildTimer.GetElapsedTimeMs()));
}
}
- m_LogOutput.SetLogOperationProgress(CleanupStep, StepCount);
+ m_Progress.SetLogOperationProgress(CleanupStep, StepCount);
std::vector<std::pair<Oid, std::string>> Result;
Result.reserve(UploadParts.size());
@@ -4864,235 +4923,256 @@ BuildsOperationUploadFolder::GenerateBuildBlocks(const ChunkedFolderContent&
{
ZEN_TRACE_CPU("GenerateBuildBlocks");
const std::size_t NewBlockCount = NewBlockChunks.size();
- if (NewBlockCount > 0)
+ if (NewBlockCount == 0)
{
- std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Generate Blocks"));
- OperationLogOutput::ProgressBar& Progress(*ProgressBarPtr);
+ return;
+ }
- OutBlocks.BlockDescriptions.resize(NewBlockCount);
- OutBlocks.BlockSizes.resize(NewBlockCount);
- OutBlocks.BlockMetaDatas.resize(NewBlockCount);
- OutBlocks.BlockHeaders.resize(NewBlockCount);
- OutBlocks.MetaDataHasBeenUploaded.resize(NewBlockCount, 0);
- OutBlocks.BlockHashToBlockIndex.reserve(NewBlockCount);
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Generate Blocks");
+
+ OutBlocks.BlockDescriptions.resize(NewBlockCount);
+ OutBlocks.BlockSizes.resize(NewBlockCount);
+ OutBlocks.BlockMetaDatas.resize(NewBlockCount);
+ OutBlocks.BlockHeaders.resize(NewBlockCount);
+ OutBlocks.MetaDataHasBeenUploaded.resize(NewBlockCount, 0);
+ OutBlocks.BlockHashToBlockIndex.reserve(NewBlockCount);
+
+ RwLock Lock;
+ FilteredRate FilteredGeneratedBytesPerSecond;
+ FilteredRate FilteredUploadedBytesPerSecond;
+ ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+ std::atomic<uint64_t> QueuedPendingBlocksForUpload = 0;
+
+ GenerateBuildBlocksContext Context{.Work = Work,
+ .GenerateBlobsPool = m_IOWorkerPool,
+ .UploadBlocksPool = m_NetworkPool,
+ .FilteredGeneratedBytesPerSecond = FilteredGeneratedBytesPerSecond,
+ .FilteredUploadedBytesPerSecond = FilteredUploadedBytesPerSecond,
+ .QueuedPendingBlocksForUpload = QueuedPendingBlocksForUpload,
+ .Lock = Lock,
+ .OutBlocks = OutBlocks,
+ .GenerateBlocksStats = GenerateBlocksStats,
+ .UploadStats = UploadStats,
+ .NewBlockCount = NewBlockCount};
+
+ ScheduleBlockGeneration(Context, Content, Lookup, NewBlockChunks);
+
+ Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
+ ZEN_UNUSED(PendingWork);
+
+ FilteredGeneratedBytesPerSecond.Update(GenerateBlocksStats.GeneratedBlockByteCount.load());
+ FilteredUploadedBytesPerSecond.Update(UploadStats.BlocksBytes.load());
+
+ std::string Details = fmt::format("Generated {}/{} ({}, {}B/s). Uploaded {}/{} ({}, {}bits/s)",
+ GenerateBlocksStats.GeneratedBlockCount.load(),
+ NewBlockCount,
+ NiceBytes(GenerateBlocksStats.GeneratedBlockByteCount.load()),
+ NiceNum(FilteredGeneratedBytesPerSecond.GetCurrent()),
+ UploadStats.BlockCount.load(),
+ NewBlockCount,
+ NiceBytes(UploadStats.BlocksBytes.load()),
+ NiceNum(FilteredUploadedBytesPerSecond.GetCurrent() * 8));
+
+ ProgressBar->UpdateState({.Task = "Generating blocks",
+ .Details = Details,
+ .TotalCount = gsl::narrow<uint64_t>(NewBlockCount),
+ .RemainingCount = gsl::narrow<uint64_t>(NewBlockCount - GenerateBlocksStats.GeneratedBlockCount.load()),
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ false);
+ });
- RwLock Lock;
+ ZEN_ASSERT(m_AbortFlag || QueuedPendingBlocksForUpload.load() == 0);
- WorkerThreadPool& GenerateBlobsPool = m_IOWorkerPool;
- WorkerThreadPool& UploadBlocksPool = m_NetworkPool;
+ ProgressBar->Finish();
- FilteredRate FilteredGeneratedBytesPerSecond;
- FilteredRate FilteredUploadedBytesPerSecond;
+ GenerateBlocksStats.GenerateBlocksElapsedWallTimeUS = FilteredGeneratedBytesPerSecond.GetElapsedTimeUS();
+ UploadStats.ElapsedWallTimeUS = FilteredUploadedBytesPerSecond.GetElapsedTimeUS();
+}
- ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+void
+BuildsOperationUploadFolder::ScheduleBlockGeneration(GenerateBuildBlocksContext& Context,
+ const ChunkedFolderContent& Content,
+ const ChunkedContentLookup& Lookup,
+ const std::vector<std::vector<uint32_t>>& NewBlockChunks)
+{
+ for (size_t BlockIndex = 0; BlockIndex < Context.NewBlockCount; BlockIndex++)
+ {
+ if (Context.Work.IsAborted())
+ {
+ break;
+ }
+ const std::vector<uint32_t>& ChunksInBlock = NewBlockChunks[BlockIndex];
+ Context.Work.ScheduleWork(
+ Context.GenerateBlobsPool,
+ [this, &Context, &Content, &Lookup, ChunksInBlock, BlockIndex](std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ ZEN_TRACE_CPU("GenerateBuildBlocks_Generate");
- std::atomic<uint64_t> QueuedPendingBlocksForUpload = 0;
+ Context.FilteredGeneratedBytesPerSecond.Start();
- for (size_t BlockIndex = 0; BlockIndex < NewBlockCount; BlockIndex++)
- {
- if (Work.IsAborted())
- {
- break;
- }
- const std::vector<uint32_t>& ChunksInBlock = NewBlockChunks[BlockIndex];
- Work.ScheduleWork(
- GenerateBlobsPool,
- [this,
- &Content,
- &Lookup,
- &Work,
- &UploadBlocksPool,
- NewBlockCount,
- ChunksInBlock,
- &Lock,
- &OutBlocks,
- &GenerateBlocksStats,
- &UploadStats,
- &FilteredGeneratedBytesPerSecond,
- &QueuedPendingBlocksForUpload,
- &FilteredUploadedBytesPerSecond,
- BlockIndex](std::atomic<bool>&) {
- if (!m_AbortFlag)
+ Stopwatch GenerateTimer;
+ CompressedBuffer CompressedBlock =
+ GenerateBlock(Content, Lookup, ChunksInBlock, Context.OutBlocks.BlockDescriptions[BlockIndex]);
+ if (m_Options.IsVerbose)
{
- ZEN_TRACE_CPU("GenerateBuildBlocks_Generate");
+ ZEN_INFO("Generated block {} ({}) containing {} chunks in {}",
+ Context.OutBlocks.BlockDescriptions[BlockIndex].BlockHash,
+ NiceBytes(CompressedBlock.GetCompressedSize()),
+ Context.OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size(),
+ NiceTimeSpanMs(GenerateTimer.GetElapsedTimeMs()));
+ }
- FilteredGeneratedBytesPerSecond.Start();
+ Context.OutBlocks.BlockSizes[BlockIndex] = CompressedBlock.GetCompressedSize();
+ {
+ CbObjectWriter Writer;
+ Writer.AddString("createdBy", "zen");
+ Context.OutBlocks.BlockMetaDatas[BlockIndex] = Writer.Save();
+ }
+ Context.GenerateBlocksStats.GeneratedBlockByteCount += Context.OutBlocks.BlockSizes[BlockIndex];
+ Context.GenerateBlocksStats.GeneratedBlockCount++;
- Stopwatch GenerateTimer;
- CompressedBuffer CompressedBlock =
- GenerateBlock(Content, Lookup, ChunksInBlock, OutBlocks.BlockDescriptions[BlockIndex]);
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Generated block {} ({}) containing {} chunks in {}",
- OutBlocks.BlockDescriptions[BlockIndex].BlockHash,
- NiceBytes(CompressedBlock.GetCompressedSize()),
- OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size(),
- NiceTimeSpanMs(GenerateTimer.GetElapsedTimeMs()));
- }
+ Context.Lock.WithExclusiveLock([&]() {
+ Context.OutBlocks.BlockHashToBlockIndex.insert_or_assign(Context.OutBlocks.BlockDescriptions[BlockIndex].BlockHash,
+ BlockIndex);
+ });
- OutBlocks.BlockSizes[BlockIndex] = CompressedBlock.GetCompressedSize();
- {
- CbObjectWriter Writer;
- Writer.AddString("createdBy", "zen");
- OutBlocks.BlockMetaDatas[BlockIndex] = Writer.Save();
- }
- GenerateBlocksStats.GeneratedBlockByteCount += OutBlocks.BlockSizes[BlockIndex];
- GenerateBlocksStats.GeneratedBlockCount++;
+ {
+ std::span<const SharedBuffer> Segments = CompressedBlock.GetCompressed().GetSegments();
+ ZEN_ASSERT(Segments.size() >= 2);
+ Context.OutBlocks.BlockHeaders[BlockIndex] = CompositeBuffer(Segments[0], Segments[1]);
+ }
- Lock.WithExclusiveLock([&]() {
- OutBlocks.BlockHashToBlockIndex.insert_or_assign(OutBlocks.BlockDescriptions[BlockIndex].BlockHash, BlockIndex);
- });
+ if (Context.GenerateBlocksStats.GeneratedBlockCount == Context.NewBlockCount)
+ {
+ Context.FilteredGeneratedBytesPerSecond.Stop();
+ }
+ if (Context.QueuedPendingBlocksForUpload.load() > 16)
+ {
+ std::span<const SharedBuffer> Segments = CompressedBlock.GetCompressed().GetSegments();
+ ZEN_ASSERT(Segments.size() >= 2);
+ Context.OutBlocks.BlockHeaders[BlockIndex] = CompositeBuffer(Segments[0], Segments[1]);
+ }
+ else
+ {
+ if (!m_AbortFlag)
{
- std::span<const SharedBuffer> Segments = CompressedBlock.GetCompressed().GetSegments();
- ZEN_ASSERT(Segments.size() >= 2);
- OutBlocks.BlockHeaders[BlockIndex] = CompositeBuffer(Segments[0], Segments[1]);
+ Context.QueuedPendingBlocksForUpload++;
+ Context.Work.ScheduleWork(
+ Context.UploadBlocksPool,
+ [this, &Context, BlockIndex, Payload = std::move(CompressedBlock)](std::atomic<bool>&) mutable {
+ UploadGeneratedBlock(Context, BlockIndex, std::move(Payload));
+ });
}
+ }
+ }
+ });
+ }
+}
- if (GenerateBlocksStats.GeneratedBlockCount == NewBlockCount)
- {
- FilteredGeneratedBytesPerSecond.Stop();
- }
+void
+BuildsOperationUploadFolder::UploadGeneratedBlock(GenerateBuildBlocksContext& Context, size_t BlockIndex, CompressedBuffer Payload)
+{
+ auto _ = MakeGuard([&Context] { Context.QueuedPendingBlocksForUpload--; });
+ if (m_AbortFlag)
+ {
+ return;
+ }
- if (QueuedPendingBlocksForUpload.load() > 16)
- {
- std::span<const SharedBuffer> Segments = CompressedBlock.GetCompressed().GetSegments();
- ZEN_ASSERT(Segments.size() >= 2);
- OutBlocks.BlockHeaders[BlockIndex] = CompositeBuffer(Segments[0], Segments[1]);
- }
- else
- {
- if (!m_AbortFlag)
- {
- QueuedPendingBlocksForUpload++;
+ if (Context.GenerateBlocksStats.GeneratedBlockCount == Context.NewBlockCount)
+ {
+ ZEN_TRACE_CPU("GenerateBuildBlocks_Save");
- Work.ScheduleWork(
- UploadBlocksPool,
- [this,
- NewBlockCount,
- &GenerateBlocksStats,
- &UploadStats,
- &FilteredUploadedBytesPerSecond,
- &QueuedPendingBlocksForUpload,
- &OutBlocks,
- BlockIndex,
- Payload = std::move(CompressedBlock)](std::atomic<bool>&) mutable {
- auto _ = MakeGuard([&QueuedPendingBlocksForUpload] { QueuedPendingBlocksForUpload--; });
- if (!m_AbortFlag)
- {
- if (GenerateBlocksStats.GeneratedBlockCount == NewBlockCount)
- {
- ZEN_TRACE_CPU("GenerateBuildBlocks_Save");
-
- FilteredUploadedBytesPerSecond.Stop();
- std::span<const SharedBuffer> Segments = Payload.GetCompressed().GetSegments();
- ZEN_ASSERT(Segments.size() >= 2);
- OutBlocks.BlockHeaders[BlockIndex] = CompositeBuffer(Segments[0], Segments[1]);
- }
- else
- {
- ZEN_TRACE_CPU("GenerateBuildBlocks_Upload");
-
- FilteredUploadedBytesPerSecond.Start();
-
- const CbObject BlockMetaData =
- BuildChunkBlockDescription(OutBlocks.BlockDescriptions[BlockIndex],
- OutBlocks.BlockMetaDatas[BlockIndex]);
-
- const IoHash& BlockHash = OutBlocks.BlockDescriptions[BlockIndex].BlockHash;
- const uint64_t CompressedBlockSize = Payload.GetCompressedSize();
-
- if (m_Storage.CacheStorage && m_Options.PopulateCache)
- {
- m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
- BlockHash,
- ZenContentType::kCompressedBinary,
- Payload.GetCompressed());
- }
-
- m_Storage.BuildStorage->PutBuildBlob(m_BuildId,
- BlockHash,
- ZenContentType::kCompressedBinary,
- std::move(Payload).GetCompressed());
- UploadStats.BlocksBytes += CompressedBlockSize;
-
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Uploaded block {} ({}) containing {} chunks",
- BlockHash,
- NiceBytes(CompressedBlockSize),
- OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size());
- }
-
- if (m_Storage.CacheStorage && m_Options.PopulateCache)
- {
- m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId,
- std::vector<IoHash>({BlockHash}),
- std::vector<CbObject>({BlockMetaData}));
- }
-
- bool MetadataSucceeded =
- m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData);
- if (MetadataSucceeded)
- {
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Uploaded block {} metadata ({})",
- BlockHash,
- NiceBytes(BlockMetaData.GetSize()));
- }
-
- OutBlocks.MetaDataHasBeenUploaded[BlockIndex] = true;
- UploadStats.BlocksBytes += BlockMetaData.GetSize();
- }
-
- UploadStats.BlockCount++;
- if (UploadStats.BlockCount == NewBlockCount)
- {
- FilteredUploadedBytesPerSecond.Stop();
- }
- }
- }
- });
- }
- }
- }
- });
+ Context.FilteredUploadedBytesPerSecond.Stop();
+ std::span<const SharedBuffer> Segments = Payload.GetCompressed().GetSegments();
+ ZEN_ASSERT(Segments.size() >= 2);
+ Context.OutBlocks.BlockHeaders[BlockIndex] = CompositeBuffer(Segments[0], Segments[1]);
+ return;
+ }
+
+ ZEN_TRACE_CPU("GenerateBuildBlocks_Upload");
+
+ Context.FilteredUploadedBytesPerSecond.Start();
+
+ const CbObject BlockMetaData =
+ BuildChunkBlockDescription(Context.OutBlocks.BlockDescriptions[BlockIndex], Context.OutBlocks.BlockMetaDatas[BlockIndex]);
+
+ const IoHash& BlockHash = Context.OutBlocks.BlockDescriptions[BlockIndex].BlockHash;
+ const uint64_t CompressedBlockSize = Payload.GetCompressedSize();
+
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
+ {
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload.GetCompressed());
+ }
+
+ try
+ {
+ m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, std::move(Payload).GetCompressed());
+ }
+ catch (const std::exception&)
+ {
+ // Silence http errors due to abort
+ if (!m_AbortFlag)
+ {
+ throw;
}
+ }
- Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
- ZEN_UNUSED(PendingWork);
+ if (m_AbortFlag)
+ {
+ return;
+ }
- FilteredGeneratedBytesPerSecond.Update(GenerateBlocksStats.GeneratedBlockByteCount.load());
- FilteredUploadedBytesPerSecond.Update(UploadStats.BlocksBytes.load());
+ Context.UploadStats.BlocksBytes += CompressedBlockSize;
- std::string Details = fmt::format("Generated {}/{} ({}, {}B/s). Uploaded {}/{} ({}, {}bits/s)",
- GenerateBlocksStats.GeneratedBlockCount.load(),
- NewBlockCount,
- NiceBytes(GenerateBlocksStats.GeneratedBlockByteCount.load()),
- NiceNum(FilteredGeneratedBytesPerSecond.GetCurrent()),
- UploadStats.BlockCount.load(),
- NewBlockCount,
- NiceBytes(UploadStats.BlocksBytes.load()),
- NiceNum(FilteredUploadedBytesPerSecond.GetCurrent() * 8));
+ if (m_Options.IsVerbose)
+ {
+ ZEN_INFO("Uploaded block {} ({}) containing {} chunks",
+ BlockHash,
+ NiceBytes(CompressedBlockSize),
+ Context.OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size());
+ }
- Progress.UpdateState({.Task = "Generating blocks",
- .Details = Details,
- .TotalCount = gsl::narrow<uint64_t>(NewBlockCount),
- .RemainingCount = gsl::narrow<uint64_t>(NewBlockCount - GenerateBlocksStats.GeneratedBlockCount.load()),
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
- false);
- });
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
+ {
+ m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, std::vector<IoHash>({BlockHash}), std::vector<CbObject>({BlockMetaData}));
+ }
- ZEN_ASSERT(m_AbortFlag || QueuedPendingBlocksForUpload.load() == 0);
+ bool MetadataSucceeded = false;
+ try
+ {
+ MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData);
+ }
+ catch (const std::exception&)
+ {
+ // Silence http errors due to abort
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
+ }
- Progress.Finish();
+ if (m_AbortFlag)
+ {
+ return;
+ }
- GenerateBlocksStats.GenerateBlocksElapsedWallTimeUS = FilteredGeneratedBytesPerSecond.GetElapsedTimeUS();
- UploadStats.ElapsedWallTimeUS = FilteredUploadedBytesPerSecond.GetElapsedTimeUS();
+ if (MetadataSucceeded)
+ {
+ if (m_Options.IsVerbose)
+ {
+ ZEN_INFO("Uploaded block {} metadata ({})", BlockHash, NiceBytes(BlockMetaData.GetSize()));
+ }
+
+ Context.OutBlocks.MetaDataHasBeenUploaded[BlockIndex] = true;
+ Context.UploadStats.BlocksBytes += BlockMetaData.GetSize();
+ }
+
+ Context.UploadStats.BlockCount++;
+ if (Context.UploadStats.BlockCount == Context.NewBlockCount)
+ {
+ Context.FilteredUploadedBytesPerSecond.Stop();
}
}
@@ -5220,7 +5300,7 @@ BuildsOperationUploadFolder::GenerateBlock(const ChunkedFolderContent& Content,
uint64_t RawSize = Chunk.GetSize();
const bool ShouldCompressChunk = RawSize >= m_Options.MinimumSizeForCompressInBlock &&
- IsChunkCompressable(m_NonCompressableExtensionHashes, Content, Lookup, ChunkIndex);
+ IsChunkCompressable(m_NonCompressableExtensionHashes, Lookup, ChunkIndex);
const OodleCompressionLevel CompressionLevel =
ShouldCompressChunk ? OodleCompressionLevel::VeryFast : OodleCompressionLevel::None;
@@ -5259,9 +5339,9 @@ BuildsOperationUploadFolder::RebuildBlock(const ChunkedFolderContent& Content,
Content.ChunkedContent.ChunkRawSizes[ChunkIndex]);
ZEN_ASSERT_SLOW(IoHash::HashBuffer(Chunk) == Content.ChunkedContent.ChunkHashes[ChunkIndex]);
- const uint64_t RawSize = Chunk.GetSize();
- const bool ShouldCompressChunk = RawSize >= m_Options.MinimumSizeForCompressInBlock &&
- IsChunkCompressable(m_NonCompressableExtensionHashes, Content, Lookup, ChunkIndex);
+ const uint64_t RawSize = Chunk.GetSize();
+ const bool ShouldCompressChunk =
+ RawSize >= m_Options.MinimumSizeForCompressInBlock && IsChunkCompressable(m_NonCompressableExtensionHashes, Lookup, ChunkIndex);
const OodleCompressionLevel CompressionLevel = ShouldCompressChunk ? OodleCompressionLevel::VeryFast : OodleCompressionLevel::None;
@@ -5287,158 +5367,40 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController
ReuseBlocksStatistics ReuseBlocksStats;
UploadStatistics UploadStats;
GenerateBlocksStatistics GenerateBlocksStats;
+ LooseChunksStatistics LooseChunksStats;
- LooseChunksStatistics LooseChunksStats;
- ChunkedFolderContent LocalContent;
-
- m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::ChunkPartContent, StepCount);
+ m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::ChunkPartContent, StepCount);
- Stopwatch ScanTimer;
- {
- std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Scan Folder"));
- OperationLogOutput::ProgressBar& Progress(*ProgressBarPtr);
-
- FilteredRate FilteredBytesHashed;
- FilteredBytesHashed.Start();
- LocalContent = ChunkFolderContent(
- ChunkingStats,
- m_IOWorkerPool,
- m_Path,
- Part.Content,
- ChunkController,
- ChunkCache,
- m_LogOutput.GetProgressUpdateDelayMS(),
- [&](bool IsAborted, bool IsPaused, std::ptrdiff_t) {
- FilteredBytesHashed.Update(ChunkingStats.BytesHashed.load());
- std::string Details = fmt::format("{}/{} ({}/{}, {}B/s) scanned, {} ({}) chunks found",
- ChunkingStats.FilesProcessed.load(),
- Part.Content.Paths.size(),
- NiceBytes(ChunkingStats.BytesHashed.load()),
- NiceBytes(Part.TotalRawSize),
- NiceNum(FilteredBytesHashed.GetCurrent()),
- ChunkingStats.UniqueChunksFound.load(),
- NiceBytes(ChunkingStats.UniqueBytesFound.load()));
- Progress.UpdateState({.Task = "Scanning files ",
- .Details = Details,
- .TotalCount = Part.TotalRawSize,
- .RemainingCount = Part.TotalRawSize - ChunkingStats.BytesHashed.load(),
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
- false);
- },
- m_AbortFlag,
- m_PauseFlag);
- FilteredBytesHashed.Stop();
- Progress.Finish();
- if (m_AbortFlag)
- {
- return;
- }
- }
-
- if (!m_Options.IsQuiet)
+ ChunkedFolderContent LocalContent = ScanPartContent(Part, ChunkController, ChunkCache, ChunkingStats);
+ if (m_AbortFlag)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Found {} ({}) files divided into {} ({}) unique chunks in '{}' in {}. Average hash rate {}B/sec",
- Part.Content.Paths.size(),
- NiceBytes(Part.TotalRawSize),
- ChunkingStats.UniqueChunksFound.load(),
- NiceBytes(ChunkingStats.UniqueBytesFound.load()),
- m_Path,
- NiceTimeSpanMs(ScanTimer.GetElapsedTimeMs()),
- NiceNum(GetBytesPerSecond(ChunkingStats.ElapsedWallTimeUS, ChunkingStats.BytesHashed)));
+ return;
}
const ChunkedContentLookup LocalLookup = BuildChunkedContentLookup(LocalContent);
- std::vector<size_t> ReuseBlockIndexes;
- std::vector<uint32_t> NewBlockChunkIndexes;
-
if (PartIndex == 0)
{
- const PrepareBuildResult PrepBuildResult = m_PrepBuildResultFuture.get();
-
- m_FindBlocksStats.FindBlockTimeMS = PrepBuildResult.ElapsedTimeMs;
- m_FindBlocksStats.FoundBlockCount = PrepBuildResult.KnownBlocks.size();
-
- if (!m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Build prepare took {}. {} took {}, payload size {}{}",
- NiceTimeSpanMs(PrepBuildResult.ElapsedTimeMs),
- m_CreateBuild ? "PutBuild" : "GetBuild",
- NiceTimeSpanMs(PrepBuildResult.PrepareBuildTimeMs),
- NiceBytes(PrepBuildResult.PayloadSize),
- m_Options.IgnoreExistingBlocks ? ""
- : fmt::format(". Found {} blocks in {}",
- PrepBuildResult.KnownBlocks.size(),
- NiceTimeSpanMs(PrepBuildResult.FindBlocksTimeMs)));
- }
-
- m_PreferredMultipartChunkSize = PrepBuildResult.PreferredMultipartChunkSize;
-
- m_LargeAttachmentSize = m_Options.AllowMultiparts ? m_PreferredMultipartChunkSize * 4u : (std::uint64_t)-1;
-
- m_KnownBlocks = std::move(PrepBuildResult.KnownBlocks);
+ ConsumePrepareBuildResult();
}
ZEN_ASSERT(m_PreferredMultipartChunkSize != 0);
ZEN_ASSERT(m_LargeAttachmentSize != 0);
- m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::CalculateDelta, StepCount);
+ m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::CalculateDelta, StepCount);
Stopwatch BlockArrangeTimer;
- std::vector<std::uint32_t> LooseChunkIndexes;
- {
- bool EnableBlocks = true;
- std::vector<std::uint32_t> BlockChunkIndexes;
- for (uint32_t ChunkIndex = 0; ChunkIndex < LocalContent.ChunkedContent.ChunkHashes.size(); ChunkIndex++)
- {
- const uint64_t ChunkRawSize = LocalContent.ChunkedContent.ChunkRawSizes[ChunkIndex];
- if (!EnableBlocks || ChunkRawSize == 0 || ChunkRawSize > m_Options.BlockParameters.MaxChunkEmbedSize)
- {
- LooseChunkIndexes.push_back(ChunkIndex);
- LooseChunksStats.ChunkByteCount += ChunkRawSize;
- }
- else
- {
- BlockChunkIndexes.push_back(ChunkIndex);
- FindBlocksStats.PotentialChunkByteCount += ChunkRawSize;
- }
- }
- FindBlocksStats.PotentialChunkCount += BlockChunkIndexes.size();
- LooseChunksStats.ChunkCount = LooseChunkIndexes.size();
-
- if (m_Options.IgnoreExistingBlocks)
- {
- if (!m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Ignoring any existing blocks in store");
- }
- NewBlockChunkIndexes = std::move(BlockChunkIndexes);
- }
- else
- {
- ReuseBlockIndexes = FindReuseBlocks(m_LogOutput,
- m_Options.BlockReuseMinPercentLimit,
- m_Options.IsVerbose,
- ReuseBlocksStats,
- m_KnownBlocks,
- LocalContent.ChunkedContent.ChunkHashes,
- BlockChunkIndexes,
- NewBlockChunkIndexes);
- FindBlocksStats.AcceptedBlockCount += ReuseBlockIndexes.size();
-
- for (const ChunkBlockDescription& Description : m_KnownBlocks)
- {
- for (uint32_t ChunkRawLength : Description.ChunkRawLengths)
- {
- FindBlocksStats.FoundBlockByteCount += ChunkRawLength;
- }
- FindBlocksStats.FoundBlockChunkCount += Description.ChunkRawHashes.size();
- }
- }
- }
+ std::vector<uint32_t> LooseChunkIndexes;
+ std::vector<uint32_t> NewBlockChunkIndexes;
+ std::vector<size_t> ReuseBlockIndexes;
+ ClassifyChunksByBlockEligibility(LocalContent,
+ LooseChunkIndexes,
+ NewBlockChunkIndexes,
+ ReuseBlockIndexes,
+ LooseChunksStats,
+ FindBlocksStats,
+ ReuseBlocksStats);
std::vector<std::vector<uint32_t>> NewBlockChunks;
ArrangeChunksIntoBlocks(LocalContent, LocalLookup, NewBlockChunkIndexes, NewBlockChunks);
@@ -5460,43 +5422,43 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController
: 0.0;
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Found {} chunks in {} ({}) blocks eligible for reuse in {}\n"
- " Reusing {} ({}) matching chunks in {} blocks ({:.1f}%)\n"
- " Accepting {} ({}) redundant chunks ({:.1f}%)\n"
- " Rejected {} ({}) chunks in {} blocks\n"
- " Arranged {} ({}) chunks in {} new blocks\n"
- " Keeping {} ({}) chunks as loose chunks\n"
- " Discovery completed in {}",
- FindBlocksStats.FoundBlockChunkCount,
- FindBlocksStats.FoundBlockCount,
- NiceBytes(FindBlocksStats.FoundBlockByteCount),
- NiceTimeSpanMs(FindBlocksStats.FindBlockTimeMS),
+ ZEN_INFO(
+ "Found {} chunks in {} ({}) blocks eligible for reuse in {}\n"
+ " Reusing {} ({}) matching chunks in {} blocks ({:.1f}%)\n"
+ " Accepting {} ({}) redundant chunks ({:.1f}%)\n"
+ " Rejected {} ({}) chunks in {} blocks\n"
+ " Arranged {} ({}) chunks in {} new blocks\n"
+ " Keeping {} ({}) chunks as loose chunks\n"
+ " Discovery completed in {}",
+ FindBlocksStats.FoundBlockChunkCount,
+ FindBlocksStats.FoundBlockCount,
+ NiceBytes(FindBlocksStats.FoundBlockByteCount),
+ NiceTimeSpanMs(FindBlocksStats.FindBlockTimeMS),
- ReuseBlocksStats.AcceptedChunkCount,
- NiceBytes(ReuseBlocksStats.AcceptedRawByteCount),
- FindBlocksStats.AcceptedBlockCount,
- AcceptedByteCountPercent,
+ ReuseBlocksStats.AcceptedChunkCount,
+ NiceBytes(ReuseBlocksStats.AcceptedRawByteCount),
+ FindBlocksStats.AcceptedBlockCount,
+ AcceptedByteCountPercent,
- ReuseBlocksStats.AcceptedReduntantChunkCount,
- NiceBytes(ReuseBlocksStats.AcceptedReduntantByteCount),
- AcceptedReduntantByteCountPercent,
+ ReuseBlocksStats.AcceptedReduntantChunkCount,
+ NiceBytes(ReuseBlocksStats.AcceptedReduntantByteCount),
+ AcceptedReduntantByteCountPercent,
- ReuseBlocksStats.RejectedChunkCount,
- NiceBytes(ReuseBlocksStats.RejectedByteCount),
- ReuseBlocksStats.RejectedBlockCount,
+ ReuseBlocksStats.RejectedChunkCount,
+ NiceBytes(ReuseBlocksStats.RejectedByteCount),
+ ReuseBlocksStats.RejectedBlockCount,
- FindBlocksStats.NewBlocksChunkCount,
- NiceBytes(FindBlocksStats.NewBlocksChunkByteCount),
- FindBlocksStats.NewBlocksCount,
+ FindBlocksStats.NewBlocksChunkCount,
+ NiceBytes(FindBlocksStats.NewBlocksChunkByteCount),
+ FindBlocksStats.NewBlocksCount,
- LooseChunksStats.ChunkCount,
- NiceBytes(LooseChunksStats.ChunkByteCount),
+ LooseChunksStats.ChunkCount,
+ NiceBytes(LooseChunksStats.ChunkByteCount),
- NiceTimeSpanMs(BlockArrangeTimer.GetElapsedTimeMs()));
+ NiceTimeSpanMs(BlockArrangeTimer.GetElapsedTimeMs()));
}
- m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::GenerateBlocks, StepCount);
+ m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::GenerateBlocks, StepCount);
GeneratedBlocks NewBlocks;
if (!NewBlockChunks.empty())
@@ -5506,295 +5468,523 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController
uint64_t BlockGenerateTimeUs = GenerateBuildBlocksTimer.GetElapsedTimeUs();
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(
- m_LogOutput,
- "Generated {} ({}) and uploaded {} ({}) blocks in {}. Generate speed: {}B/sec. Transfer speed {}bits/sec.",
- GenerateBlocksStats.GeneratedBlockCount.load(),
- NiceBytes(GenerateBlocksStats.GeneratedBlockByteCount),
- UploadStats.BlockCount.load(),
- NiceBytes(UploadStats.BlocksBytes.load()),
- NiceTimeSpanMs(BlockGenerateTimeUs / 1000),
- NiceNum(GetBytesPerSecond(GenerateBlocksStats.GenerateBlocksElapsedWallTimeUS,
- GenerateBlocksStats.GeneratedBlockByteCount)),
- NiceNum(GetBytesPerSecond(UploadStats.ElapsedWallTimeUS, UploadStats.BlocksBytes * 8)));
+ ZEN_INFO("Generated {} ({}) and uploaded {} ({}) blocks in {}. Generate speed: {}B/sec. Transfer speed {}bits/sec.",
+ GenerateBlocksStats.GeneratedBlockCount.load(),
+ NiceBytes(GenerateBlocksStats.GeneratedBlockByteCount),
+ UploadStats.BlockCount.load(),
+ NiceBytes(UploadStats.BlocksBytes.load()),
+ NiceTimeSpanMs(BlockGenerateTimeUs / 1000),
+ NiceNum(GetBytesPerSecond(GenerateBlocksStats.GenerateBlocksElapsedWallTimeUS,
+ GenerateBlocksStats.GeneratedBlockByteCount)),
+ NiceNum(GetBytesPerSecond(UploadStats.ElapsedWallTimeUS, UploadStats.BlocksBytes * 8)));
}
});
GenerateBuildBlocks(LocalContent, LocalLookup, NewBlockChunks, NewBlocks, GenerateBlocksStats, UploadStats);
}
- m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::BuildPartManifest, StepCount);
+ m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::BuildPartManifest, StepCount);
+
+ BuiltPartManifest Manifest =
+ BuildPartManifestObject(LocalContent, LocalLookup, ChunkController, ReuseBlockIndexes, NewBlocks, LooseChunkIndexes);
+
+ m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::UploadBuildPart, StepCount);
- CbObject PartManifest;
+ Stopwatch PutBuildPartResultTimer;
+ std::pair<IoHash, std::vector<IoHash>> PutBuildPartResult =
+ m_Storage.BuildStorage->PutBuildPart(m_BuildId, Part.PartId, Part.PartName, Manifest.PartManifest);
+ if (!m_Options.IsQuiet)
{
- CbObjectWriter PartManifestWriter;
- Stopwatch ManifestGenerationTimer;
- auto __ = MakeGuard([&]() {
- if (!m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Generated build part manifest in {} ({})",
- NiceTimeSpanMs(ManifestGenerationTimer.GetElapsedTimeMs()),
- NiceBytes(PartManifestWriter.GetSaveSize()));
- }
- });
+ ZEN_INFO("PutBuildPart took {}, payload size {}. {} attachments are needed.",
+ NiceTimeSpanMs(PutBuildPartResultTimer.GetElapsedTimeMs()),
+ NiceBytes(Manifest.PartManifest.GetSize()),
+ PutBuildPartResult.second.size());
+ }
+ IoHash PartHash = PutBuildPartResult.first;
- PartManifestWriter.BeginObject("chunker"sv);
- {
- PartManifestWriter.AddString("name"sv, ChunkController.GetName());
- PartManifestWriter.AddObject("parameters"sv, ChunkController.GetParameters());
- }
- PartManifestWriter.EndObject(); // chunker
+ m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::UploadAttachments, StepCount);
- std::vector<IoHash> AllChunkBlockHashes;
- std::vector<ChunkBlockDescription> AllChunkBlockDescriptions;
- AllChunkBlockHashes.reserve(ReuseBlockIndexes.size() + NewBlocks.BlockDescriptions.size());
- AllChunkBlockDescriptions.reserve(ReuseBlockIndexes.size() + NewBlocks.BlockDescriptions.size());
- for (size_t ReuseBlockIndex : ReuseBlockIndexes)
- {
- AllChunkBlockDescriptions.push_back(m_KnownBlocks[ReuseBlockIndex]);
- AllChunkBlockHashes.push_back(m_KnownBlocks[ReuseBlockIndex].BlockHash);
- }
- AllChunkBlockDescriptions.insert(AllChunkBlockDescriptions.end(),
- NewBlocks.BlockDescriptions.begin(),
- NewBlocks.BlockDescriptions.end());
- for (const ChunkBlockDescription& BlockDescription : NewBlocks.BlockDescriptions)
+ std::vector<IoHash> UnknownChunks;
+ if (m_Options.IgnoreExistingBlocks)
+ {
+ if (m_Options.IsVerbose)
{
- AllChunkBlockHashes.push_back(BlockDescription.BlockHash);
+ ZEN_INFO("PutBuildPart uploading all attachments, needs are: {}", FormatArray<IoHash>(PutBuildPartResult.second, "\n "sv));
}
- std::vector<IoHash> AbsoluteChunkHashes;
- if (m_Options.DoExtraContentValidation)
+ std::vector<IoHash> ForceUploadChunkHashes;
+ ForceUploadChunkHashes.reserve(LooseChunkIndexes.size());
+
+ for (uint32_t ChunkIndex : LooseChunkIndexes)
{
- tsl::robin_map<IoHash, size_t, IoHash::Hasher> ChunkHashToAbsoluteChunkIndex;
- AbsoluteChunkHashes.reserve(LocalContent.ChunkedContent.ChunkHashes.size());
- for (uint32_t ChunkIndex : LooseChunkIndexes)
- {
- ChunkHashToAbsoluteChunkIndex.insert({LocalContent.ChunkedContent.ChunkHashes[ChunkIndex], AbsoluteChunkHashes.size()});
- AbsoluteChunkHashes.push_back(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]);
- }
- for (const ChunkBlockDescription& Block : AllChunkBlockDescriptions)
- {
- for (const IoHash& ChunkHash : Block.ChunkRawHashes)
- {
- ChunkHashToAbsoluteChunkIndex.insert({ChunkHash, AbsoluteChunkHashes.size()});
- AbsoluteChunkHashes.push_back(ChunkHash);
- }
- }
- for (const IoHash& ChunkHash : LocalContent.ChunkedContent.ChunkHashes)
- {
- ZEN_ASSERT(AbsoluteChunkHashes[ChunkHashToAbsoluteChunkIndex.at(ChunkHash)] == ChunkHash);
- ZEN_ASSERT(LocalContent.ChunkedContent.ChunkHashes[LocalLookup.ChunkHashToChunkIndex.at(ChunkHash)] == ChunkHash);
- }
- for (const uint32_t ChunkIndex : LocalContent.ChunkedContent.ChunkOrders)
- {
- ZEN_ASSERT(AbsoluteChunkHashes[ChunkHashToAbsoluteChunkIndex.at(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex])] ==
- LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]);
- ZEN_ASSERT(LocalLookup.ChunkHashToChunkIndex.at(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]) == ChunkIndex);
- }
+ ForceUploadChunkHashes.push_back(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]);
}
- std::vector<uint32_t> AbsoluteChunkOrders = CalculateAbsoluteChunkOrders(LocalContent.ChunkedContent.ChunkHashes,
- LocalContent.ChunkedContent.ChunkOrders,
- LocalLookup.ChunkHashToChunkIndex,
- LooseChunkIndexes,
- AllChunkBlockDescriptions);
- if (m_Options.DoExtraContentValidation)
+ for (size_t BlockIndex = 0; BlockIndex < NewBlocks.BlockHeaders.size(); BlockIndex++)
{
- for (uint32_t ChunkOrderIndex = 0; ChunkOrderIndex < LocalContent.ChunkedContent.ChunkOrders.size(); ChunkOrderIndex++)
+ if (NewBlocks.BlockHeaders[BlockIndex])
{
- uint32_t LocalChunkIndex = LocalContent.ChunkedContent.ChunkOrders[ChunkOrderIndex];
- uint32_t AbsoluteChunkIndex = AbsoluteChunkOrders[ChunkOrderIndex];
- const IoHash& LocalChunkHash = LocalContent.ChunkedContent.ChunkHashes[LocalChunkIndex];
- const IoHash& AbsoluteChunkHash = AbsoluteChunkHashes[AbsoluteChunkIndex];
- ZEN_ASSERT(LocalChunkHash == AbsoluteChunkHash);
+ // Block was not uploaded during generation
+ ForceUploadChunkHashes.push_back(NewBlocks.BlockDescriptions[BlockIndex].BlockHash);
}
}
-
- WriteBuildContentToCompactBinary(PartManifestWriter,
- LocalContent.Platform,
- LocalContent.Paths,
- LocalContent.RawHashes,
- LocalContent.RawSizes,
- LocalContent.Attributes,
- LocalContent.ChunkedContent.SequenceRawHashes,
- LocalContent.ChunkedContent.ChunkCounts,
- LocalContent.ChunkedContent.ChunkHashes,
- LocalContent.ChunkedContent.ChunkRawSizes,
- AbsoluteChunkOrders,
- LooseChunkIndexes,
- AllChunkBlockHashes);
-
- if (m_Options.DoExtraContentValidation)
+ UploadAttachmentBatch(ForceUploadChunkHashes,
+ UnknownChunks,
+ LocalContent,
+ LocalLookup,
+ NewBlockChunks,
+ NewBlocks,
+ LooseChunkIndexes,
+ UploadStats,
+ LooseChunksStats);
+ }
+ else if (!PutBuildPartResult.second.empty())
+ {
+ if (m_Options.IsVerbose)
{
- ChunkedFolderContent VerifyFolderContent;
-
- std::vector<uint32_t> OutAbsoluteChunkOrders;
- std::vector<IoHash> OutLooseChunkHashes;
- std::vector<uint64_t> OutLooseChunkRawSizes;
- std::vector<IoHash> OutBlockRawHashes;
- ReadBuildContentFromCompactBinary(PartManifestWriter.Save(),
- VerifyFolderContent.Platform,
- VerifyFolderContent.Paths,
- VerifyFolderContent.RawHashes,
- VerifyFolderContent.RawSizes,
- VerifyFolderContent.Attributes,
- VerifyFolderContent.ChunkedContent.SequenceRawHashes,
- VerifyFolderContent.ChunkedContent.ChunkCounts,
- OutAbsoluteChunkOrders,
- OutLooseChunkHashes,
- OutLooseChunkRawSizes,
- OutBlockRawHashes);
- ZEN_ASSERT(OutBlockRawHashes == AllChunkBlockHashes);
+ ZEN_INFO("PutBuildPart needs attachments: {}", FormatArray<IoHash>(PutBuildPartResult.second, "\n "sv));
+ }
+ UploadAttachmentBatch(PutBuildPartResult.second,
+ UnknownChunks,
+ LocalContent,
+ LocalLookup,
+ NewBlockChunks,
+ NewBlocks,
+ LooseChunkIndexes,
+ UploadStats,
+ LooseChunksStats);
+ }
+
+ FinalizeBuildPartWithRetries(Part,
+ PartHash,
+ UnknownChunks,
+ LocalContent,
+ LocalLookup,
+ NewBlockChunks,
+ NewBlocks,
+ LooseChunkIndexes,
+ UploadStats,
+ LooseChunksStats);
- for (uint32_t OrderIndex = 0; OrderIndex < OutAbsoluteChunkOrders.size(); OrderIndex++)
- {
- uint32_t LocalChunkIndex = LocalContent.ChunkedContent.ChunkOrders[OrderIndex];
- const IoHash LocalChunkHash = LocalContent.ChunkedContent.ChunkHashes[LocalChunkIndex];
+ if (!NewBlocks.BlockDescriptions.empty() && !m_AbortFlag)
+ {
+ UploadMissingBlockMetadata(NewBlocks, UploadStats);
+ // The newly generated blocks are now known blocks so the next part upload can use those blocks as well
+ m_KnownBlocks.insert(m_KnownBlocks.end(), NewBlocks.BlockDescriptions.begin(), NewBlocks.BlockDescriptions.end());
+ }
- uint32_t VerifyChunkIndex = OutAbsoluteChunkOrders[OrderIndex];
- const IoHash VerifyChunkHash = AbsoluteChunkHashes[VerifyChunkIndex];
+ m_Progress.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::PutBuildPartStats, StepCount);
- ZEN_ASSERT(LocalChunkHash == VerifyChunkHash);
- }
+ m_Storage.BuildStorage->PutBuildPartStats(
+ m_BuildId,
+ Part.PartId,
+ {{"totalSize", double(Part.LocalFolderScanStats.FoundFileByteCount.load())},
+ {"reusedRatio", AcceptedByteCountPercent / 100.0},
+ {"reusedBlockCount", double(FindBlocksStats.AcceptedBlockCount)},
+ {"reusedBlockByteCount", double(ReuseBlocksStats.AcceptedRawByteCount)},
+ {"newBlockCount", double(FindBlocksStats.NewBlocksCount)},
+ {"newBlockByteCount", double(FindBlocksStats.NewBlocksChunkByteCount)},
+ {"uploadedCount", double(UploadStats.BlockCount.load() + UploadStats.ChunkCount.load())},
+ {"uploadedByteCount", double(UploadStats.BlocksBytes.load() + UploadStats.ChunksBytes.load())},
+ {"uploadedBytesPerSec",
+ double(GetBytesPerSecond(UploadStats.ElapsedWallTimeUS, UploadStats.ChunksBytes + UploadStats.BlocksBytes))},
+ {"elapsedTimeSec", double(UploadTimer.GetElapsedTimeMs() / 1000.0)}});
- CalculateLocalChunkOrders(OutAbsoluteChunkOrders,
- OutLooseChunkHashes,
- OutLooseChunkRawSizes,
- AllChunkBlockDescriptions,
- VerifyFolderContent.ChunkedContent.ChunkHashes,
- VerifyFolderContent.ChunkedContent.ChunkRawSizes,
- VerifyFolderContent.ChunkedContent.ChunkOrders,
- m_Options.DoExtraContentValidation);
+ m_LocalFolderScanStats += Part.LocalFolderScanStats;
+ m_ChunkingStats += ChunkingStats;
+ m_FindBlocksStats += FindBlocksStats;
+ m_ReuseBlocksStats += ReuseBlocksStats;
+ m_UploadStats += UploadStats;
+ m_GenerateBlocksStats += GenerateBlocksStats;
+ m_LooseChunksStats += LooseChunksStats;
+}
- ZEN_ASSERT(LocalContent.Paths == VerifyFolderContent.Paths);
- ZEN_ASSERT(LocalContent.RawHashes == VerifyFolderContent.RawHashes);
- ZEN_ASSERT(LocalContent.RawSizes == VerifyFolderContent.RawSizes);
- ZEN_ASSERT(LocalContent.Attributes == VerifyFolderContent.Attributes);
- ZEN_ASSERT(LocalContent.ChunkedContent.SequenceRawHashes == VerifyFolderContent.ChunkedContent.SequenceRawHashes);
- ZEN_ASSERT(LocalContent.ChunkedContent.ChunkCounts == VerifyFolderContent.ChunkedContent.ChunkCounts);
+ChunkedFolderContent
+BuildsOperationUploadFolder::ScanPartContent(const UploadPart& Part,
+ ChunkingController& ChunkController,
+ ChunkingCache& ChunkCache,
+ ChunkingStatistics& ChunkingStats)
+{
+ Stopwatch ScanTimer;
- for (uint32_t OrderIndex = 0; OrderIndex < LocalContent.ChunkedContent.ChunkOrders.size(); OrderIndex++)
- {
- uint32_t LocalChunkIndex = LocalContent.ChunkedContent.ChunkOrders[OrderIndex];
- const IoHash LocalChunkHash = LocalContent.ChunkedContent.ChunkHashes[LocalChunkIndex];
- uint64_t LocalChunkRawSize = LocalContent.ChunkedContent.ChunkRawSizes[LocalChunkIndex];
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Scan Folder");
- uint32_t VerifyChunkIndex = VerifyFolderContent.ChunkedContent.ChunkOrders[OrderIndex];
- const IoHash VerifyChunkHash = VerifyFolderContent.ChunkedContent.ChunkHashes[VerifyChunkIndex];
- uint64_t VerifyChunkRawSize = VerifyFolderContent.ChunkedContent.ChunkRawSizes[VerifyChunkIndex];
+ FilteredRate FilteredBytesHashed;
+ FilteredBytesHashed.Start();
+ ChunkedFolderContent LocalContent = ChunkFolderContent(
+ ChunkingStats,
+ m_IOWorkerPool,
+ m_Path,
+ Part.Content,
+ ChunkController,
+ ChunkCache,
+ m_Progress.GetProgressUpdateDelayMS(),
+ [&](bool IsAborted, bool IsPaused, std::ptrdiff_t) {
+ FilteredBytesHashed.Update(ChunkingStats.BytesHashed.load());
+ std::string Details = fmt::format("{}/{} ({}/{}, {}B/s) scanned, {} ({}) chunks found",
+ ChunkingStats.FilesProcessed.load(),
+ Part.Content.Paths.size(),
+ NiceBytes(ChunkingStats.BytesHashed.load()),
+ NiceBytes(Part.TotalRawSize),
+ NiceNum(FilteredBytesHashed.GetCurrent()),
+ ChunkingStats.UniqueChunksFound.load(),
+ NiceBytes(ChunkingStats.UniqueBytesFound.load()));
+ ProgressBar->UpdateState({.Task = "Scanning files ",
+ .Details = Details,
+ .TotalCount = Part.TotalRawSize,
+ .RemainingCount = Part.TotalRawSize - ChunkingStats.BytesHashed.load(),
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ false);
+ },
+ m_AbortFlag,
+ m_PauseFlag);
+ FilteredBytesHashed.Stop();
+ ProgressBar->Finish();
+ if (m_AbortFlag)
+ {
+ return LocalContent;
+ }
- ZEN_ASSERT(LocalChunkHash == VerifyChunkHash);
- ZEN_ASSERT(LocalChunkRawSize == VerifyChunkRawSize);
- }
- }
- PartManifest = PartManifestWriter.Save();
+ if (!m_Options.IsQuiet)
+ {
+ ZEN_INFO("Found {} ({}) files divided into {} ({}) unique chunks in '{}' in {}. Average hash rate {}B/sec",
+ Part.Content.Paths.size(),
+ NiceBytes(Part.TotalRawSize),
+ ChunkingStats.UniqueChunksFound.load(),
+ NiceBytes(ChunkingStats.UniqueBytesFound.load()),
+ m_Path,
+ NiceTimeSpanMs(ScanTimer.GetElapsedTimeMs()),
+ NiceNum(GetBytesPerSecond(ChunkingStats.ElapsedWallTimeUS, ChunkingStats.BytesHashed)));
}
- m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::UploadBuildPart, StepCount);
+ return LocalContent;
+}
+
+void
+BuildsOperationUploadFolder::ConsumePrepareBuildResult()
+{
+ const PrepareBuildResult PrepBuildResult = m_PrepBuildResultFuture.get();
+
+ m_FindBlocksStats.FindBlockTimeMS = PrepBuildResult.ElapsedTimeMs;
+ m_FindBlocksStats.FoundBlockCount = PrepBuildResult.KnownBlocks.size();
- Stopwatch PutBuildPartResultTimer;
- std::pair<IoHash, std::vector<IoHash>> PutBuildPartResult =
- m_Storage.BuildStorage->PutBuildPart(m_BuildId, Part.PartId, Part.PartName, PartManifest);
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "PutBuildPart took {}, payload size {}. {} attachments are needed.",
- NiceTimeSpanMs(PutBuildPartResultTimer.GetElapsedTimeMs()),
- NiceBytes(PartManifest.GetSize()),
- PutBuildPartResult.second.size());
+ ZEN_INFO("Build prepare took {}. {} took {}, payload size {}{}",
+ NiceTimeSpanMs(PrepBuildResult.ElapsedTimeMs),
+ m_CreateBuild ? "PutBuild" : "GetBuild",
+ NiceTimeSpanMs(PrepBuildResult.PrepareBuildTimeMs),
+ NiceBytes(PrepBuildResult.PayloadSize),
+ m_Options.IgnoreExistingBlocks ? ""
+ : fmt::format(". Found {} blocks in {}",
+ PrepBuildResult.KnownBlocks.size(),
+ NiceTimeSpanMs(PrepBuildResult.FindBlocksTimeMs)));
}
- IoHash PartHash = PutBuildPartResult.first;
- auto UploadAttachments =
- [this, &LooseChunksStats, &UploadStats, &LocalContent, &LocalLookup, &NewBlockChunks, &NewBlocks, &LooseChunkIndexes](
- std::span<IoHash> RawHashes,
- std::vector<IoHash>& OutUnknownChunks) {
- if (!m_AbortFlag)
- {
- UploadStatistics TempUploadStats;
- LooseChunksStatistics TempLooseChunksStats;
-
- Stopwatch TempUploadTimer;
- auto __ = MakeGuard([&]() {
- if (!m_Options.IsQuiet)
- {
- uint64_t TempChunkUploadTimeUs = TempUploadTimer.GetElapsedTimeUs();
- ZEN_OPERATION_LOG_INFO(
- m_LogOutput,
- "Uploaded {} ({}) blocks. "
- "Compressed {} ({} {}B/s) and uploaded {} ({}) chunks. "
- "Transferred {} ({}bits/s) in {}",
- TempUploadStats.BlockCount.load(),
- NiceBytes(TempUploadStats.BlocksBytes),
-
- TempLooseChunksStats.CompressedChunkCount.load(),
- NiceBytes(TempLooseChunksStats.CompressedChunkBytes.load()),
- NiceNum(GetBytesPerSecond(TempLooseChunksStats.CompressChunksElapsedWallTimeUS,
- TempLooseChunksStats.ChunkByteCount)),
- TempUploadStats.ChunkCount.load(),
- NiceBytes(TempUploadStats.ChunksBytes),
-
- NiceBytes(TempUploadStats.BlocksBytes + TempUploadStats.ChunksBytes),
- NiceNum(GetBytesPerSecond(TempUploadStats.ElapsedWallTimeUS, TempUploadStats.ChunksBytes * 8)),
- NiceTimeSpanMs(TempChunkUploadTimeUs / 1000));
- }
- });
- UploadPartBlobs(LocalContent,
- LocalLookup,
- RawHashes,
- NewBlockChunks,
- NewBlocks,
- LooseChunkIndexes,
- m_LargeAttachmentSize,
- TempUploadStats,
- TempLooseChunksStats,
- OutUnknownChunks);
- UploadStats += TempUploadStats;
- LooseChunksStats += TempLooseChunksStats;
- }
- };
+ m_PreferredMultipartChunkSize = PrepBuildResult.PreferredMultipartChunkSize;
+ m_LargeAttachmentSize = m_Options.AllowMultiparts ? m_PreferredMultipartChunkSize * 4u : (std::uint64_t)-1;
+ m_KnownBlocks = std::move(PrepBuildResult.KnownBlocks);
+}
- m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::UploadAttachments, StepCount);
+void
+BuildsOperationUploadFolder::ClassifyChunksByBlockEligibility(const ChunkedFolderContent& LocalContent,
+ std::vector<uint32_t>& OutLooseChunkIndexes,
+ std::vector<uint32_t>& OutNewBlockChunkIndexes,
+ std::vector<size_t>& OutReuseBlockIndexes,
+ LooseChunksStatistics& LooseChunksStats,
+ FindBlocksStatistics& FindBlocksStats,
+ ReuseBlocksStatistics& ReuseBlocksStats)
+{
+ const bool EnableBlocks = true;
+ std::vector<std::uint32_t> BlockChunkIndexes;
+ for (uint32_t ChunkIndex = 0; ChunkIndex < LocalContent.ChunkedContent.ChunkHashes.size(); ChunkIndex++)
+ {
+ const uint64_t ChunkRawSize = LocalContent.ChunkedContent.ChunkRawSizes[ChunkIndex];
+ if (!EnableBlocks || ChunkRawSize == 0 || ChunkRawSize > m_Options.BlockParameters.MaxChunkEmbedSize)
+ {
+ OutLooseChunkIndexes.push_back(ChunkIndex);
+ LooseChunksStats.ChunkByteCount += ChunkRawSize;
+ }
+ else
+ {
+ BlockChunkIndexes.push_back(ChunkIndex);
+ FindBlocksStats.PotentialChunkByteCount += ChunkRawSize;
+ }
+ }
+ FindBlocksStats.PotentialChunkCount += BlockChunkIndexes.size();
+ LooseChunksStats.ChunkCount = OutLooseChunkIndexes.size();
- std::vector<IoHash> UnknownChunks;
if (m_Options.IgnoreExistingBlocks)
{
- if (m_Options.IsVerbose)
+ if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "PutBuildPart uploading all attachments, needs are: {}",
- FormatArray<IoHash>(PutBuildPartResult.second, "\n "sv));
+ ZEN_INFO("Ignoring any existing blocks in store");
}
+ OutNewBlockChunkIndexes = std::move(BlockChunkIndexes);
+ return;
+ }
- std::vector<IoHash> ForceUploadChunkHashes;
- ForceUploadChunkHashes.reserve(LooseChunkIndexes.size());
+ OutReuseBlockIndexes = FindReuseBlocks(Log(),
+ m_Options.BlockReuseMinPercentLimit,
+ m_Options.IsVerbose,
+ ReuseBlocksStats,
+ m_KnownBlocks,
+ LocalContent.ChunkedContent.ChunkHashes,
+ BlockChunkIndexes,
+ OutNewBlockChunkIndexes);
+ FindBlocksStats.AcceptedBlockCount += OutReuseBlockIndexes.size();
- for (uint32_t ChunkIndex : LooseChunkIndexes)
+ for (const ChunkBlockDescription& Description : m_KnownBlocks)
+ {
+ for (uint32_t ChunkRawLength : Description.ChunkRawLengths)
{
- ForceUploadChunkHashes.push_back(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]);
+ FindBlocksStats.FoundBlockByteCount += ChunkRawLength;
}
+ FindBlocksStats.FoundBlockChunkCount += Description.ChunkRawHashes.size();
+ }
+}
- for (size_t BlockIndex = 0; BlockIndex < NewBlocks.BlockHeaders.size(); BlockIndex++)
+BuildsOperationUploadFolder::BuiltPartManifest
+BuildsOperationUploadFolder::BuildPartManifestObject(const ChunkedFolderContent& LocalContent,
+ const ChunkedContentLookup& LocalLookup,
+ ChunkingController& ChunkController,
+ std::span<const size_t> ReuseBlockIndexes,
+ const GeneratedBlocks& NewBlocks,
+ std::span<const uint32_t> LooseChunkIndexes)
+{
+ BuiltPartManifest Result;
+
+ CbObjectWriter PartManifestWriter;
+ Stopwatch ManifestGenerationTimer;
+ auto __ = MakeGuard([&]() {
+ if (!m_Options.IsQuiet)
+ {
+ ZEN_INFO("Generated build part manifest in {} ({})",
+ NiceTimeSpanMs(ManifestGenerationTimer.GetElapsedTimeMs()),
+ NiceBytes(PartManifestWriter.GetSaveSize()));
+ }
+ });
+
+ PartManifestWriter.BeginObject("chunker"sv);
+ {
+ PartManifestWriter.AddString("name"sv, ChunkController.GetName());
+ PartManifestWriter.AddObject("parameters"sv, ChunkController.GetParameters());
+ }
+ PartManifestWriter.EndObject(); // chunker
+
+ Result.AllChunkBlockHashes.reserve(ReuseBlockIndexes.size() + NewBlocks.BlockDescriptions.size());
+ Result.AllChunkBlockDescriptions.reserve(ReuseBlockIndexes.size() + NewBlocks.BlockDescriptions.size());
+ for (size_t ReuseBlockIndex : ReuseBlockIndexes)
+ {
+ Result.AllChunkBlockDescriptions.push_back(m_KnownBlocks[ReuseBlockIndex]);
+ Result.AllChunkBlockHashes.push_back(m_KnownBlocks[ReuseBlockIndex].BlockHash);
+ }
+ Result.AllChunkBlockDescriptions.insert(Result.AllChunkBlockDescriptions.end(),
+ NewBlocks.BlockDescriptions.begin(),
+ NewBlocks.BlockDescriptions.end());
+ for (const ChunkBlockDescription& BlockDescription : NewBlocks.BlockDescriptions)
+ {
+ Result.AllChunkBlockHashes.push_back(BlockDescription.BlockHash);
+ }
+
+ std::vector<IoHash> AbsoluteChunkHashes;
+ if (m_Options.DoExtraContentValidation)
+ {
+ tsl::robin_map<IoHash, size_t, IoHash::Hasher> ChunkHashToAbsoluteChunkIndex;
+ AbsoluteChunkHashes.reserve(LocalContent.ChunkedContent.ChunkHashes.size());
+ for (uint32_t ChunkIndex : LooseChunkIndexes)
{
- if (NewBlocks.BlockHeaders[BlockIndex])
+ ChunkHashToAbsoluteChunkIndex.insert({LocalContent.ChunkedContent.ChunkHashes[ChunkIndex], AbsoluteChunkHashes.size()});
+ AbsoluteChunkHashes.push_back(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]);
+ }
+ for (const ChunkBlockDescription& Block : Result.AllChunkBlockDescriptions)
+ {
+ for (const IoHash& ChunkHash : Block.ChunkRawHashes)
{
- // Block was not uploaded during generation
- ForceUploadChunkHashes.push_back(NewBlocks.BlockDescriptions[BlockIndex].BlockHash);
+ ChunkHashToAbsoluteChunkIndex.insert({ChunkHash, AbsoluteChunkHashes.size()});
+ AbsoluteChunkHashes.push_back(ChunkHash);
}
}
- UploadAttachments(ForceUploadChunkHashes, UnknownChunks);
+ for (const IoHash& ChunkHash : LocalContent.ChunkedContent.ChunkHashes)
+ {
+ ZEN_ASSERT(AbsoluteChunkHashes[ChunkHashToAbsoluteChunkIndex.at(ChunkHash)] == ChunkHash);
+ ZEN_ASSERT(LocalContent.ChunkedContent.ChunkHashes[LocalLookup.ChunkHashToChunkIndex.at(ChunkHash)] == ChunkHash);
+ }
+ for (const uint32_t ChunkIndex : LocalContent.ChunkedContent.ChunkOrders)
+ {
+ ZEN_ASSERT(AbsoluteChunkHashes[ChunkHashToAbsoluteChunkIndex.at(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex])] ==
+ LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]);
+ ZEN_ASSERT(LocalLookup.ChunkHashToChunkIndex.at(LocalContent.ChunkedContent.ChunkHashes[ChunkIndex]) == ChunkIndex);
+ }
}
- else if (!PutBuildPartResult.second.empty())
+
+ std::vector<uint32_t> AbsoluteChunkOrders = CalculateAbsoluteChunkOrders(LocalContent.ChunkedContent.ChunkHashes,
+ LocalContent.ChunkedContent.ChunkOrders,
+ LocalLookup.ChunkHashToChunkIndex,
+ LooseChunkIndexes,
+ Result.AllChunkBlockDescriptions);
+
+ if (m_Options.DoExtraContentValidation)
{
- if (m_Options.IsVerbose)
+ for (uint32_t ChunkOrderIndex = 0; ChunkOrderIndex < LocalContent.ChunkedContent.ChunkOrders.size(); ChunkOrderIndex++)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "PutBuildPart needs attachments: {}",
- FormatArray<IoHash>(PutBuildPartResult.second, "\n "sv));
+ uint32_t LocalChunkIndex = LocalContent.ChunkedContent.ChunkOrders[ChunkOrderIndex];
+ uint32_t AbsoluteChunkIndex = AbsoluteChunkOrders[ChunkOrderIndex];
+ const IoHash& LocalChunkHash = LocalContent.ChunkedContent.ChunkHashes[LocalChunkIndex];
+ const IoHash& AbsoluteChunkHash = AbsoluteChunkHashes[AbsoluteChunkIndex];
+ ZEN_ASSERT(LocalChunkHash == AbsoluteChunkHash);
}
- UploadAttachments(PutBuildPartResult.second, UnknownChunks);
}
+ WriteBuildContentToCompactBinary(PartManifestWriter,
+ LocalContent.Platform,
+ LocalContent.Paths,
+ LocalContent.RawHashes,
+ LocalContent.RawSizes,
+ LocalContent.Attributes,
+ LocalContent.ChunkedContent.SequenceRawHashes,
+ LocalContent.ChunkedContent.ChunkCounts,
+ LocalContent.ChunkedContent.ChunkHashes,
+ LocalContent.ChunkedContent.ChunkRawSizes,
+ AbsoluteChunkOrders,
+ LooseChunkIndexes,
+ Result.AllChunkBlockHashes);
+
+ if (m_Options.DoExtraContentValidation)
+ {
+ ChunkedFolderContent VerifyFolderContent;
+
+ std::vector<uint32_t> OutAbsoluteChunkOrders;
+ std::vector<IoHash> OutLooseChunkHashes;
+ std::vector<uint64_t> OutLooseChunkRawSizes;
+ std::vector<IoHash> OutBlockRawHashes;
+ ReadBuildContentFromCompactBinary(PartManifestWriter.Save(),
+ VerifyFolderContent.Platform,
+ VerifyFolderContent.Paths,
+ VerifyFolderContent.RawHashes,
+ VerifyFolderContent.RawSizes,
+ VerifyFolderContent.Attributes,
+ VerifyFolderContent.ChunkedContent.SequenceRawHashes,
+ VerifyFolderContent.ChunkedContent.ChunkCounts,
+ OutAbsoluteChunkOrders,
+ OutLooseChunkHashes,
+ OutLooseChunkRawSizes,
+ OutBlockRawHashes);
+ ZEN_ASSERT(OutBlockRawHashes == Result.AllChunkBlockHashes);
+
+ for (uint32_t OrderIndex = 0; OrderIndex < OutAbsoluteChunkOrders.size(); OrderIndex++)
+ {
+ uint32_t LocalChunkIndex = LocalContent.ChunkedContent.ChunkOrders[OrderIndex];
+ const IoHash LocalChunkHash = LocalContent.ChunkedContent.ChunkHashes[LocalChunkIndex];
+
+ uint32_t VerifyChunkIndex = OutAbsoluteChunkOrders[OrderIndex];
+ const IoHash VerifyChunkHash = AbsoluteChunkHashes[VerifyChunkIndex];
+
+ ZEN_ASSERT(LocalChunkHash == VerifyChunkHash);
+ }
+
+ CalculateLocalChunkOrders(OutAbsoluteChunkOrders,
+ OutLooseChunkHashes,
+ OutLooseChunkRawSizes,
+ Result.AllChunkBlockDescriptions,
+ VerifyFolderContent.ChunkedContent.ChunkHashes,
+ VerifyFolderContent.ChunkedContent.ChunkRawSizes,
+ VerifyFolderContent.ChunkedContent.ChunkOrders,
+ m_Options.DoExtraContentValidation);
+
+ ZEN_ASSERT(LocalContent.Paths == VerifyFolderContent.Paths);
+ ZEN_ASSERT(LocalContent.RawHashes == VerifyFolderContent.RawHashes);
+ ZEN_ASSERT(LocalContent.RawSizes == VerifyFolderContent.RawSizes);
+ ZEN_ASSERT(LocalContent.Attributes == VerifyFolderContent.Attributes);
+ ZEN_ASSERT(LocalContent.ChunkedContent.SequenceRawHashes == VerifyFolderContent.ChunkedContent.SequenceRawHashes);
+ ZEN_ASSERT(LocalContent.ChunkedContent.ChunkCounts == VerifyFolderContent.ChunkedContent.ChunkCounts);
+
+ for (uint32_t OrderIndex = 0; OrderIndex < LocalContent.ChunkedContent.ChunkOrders.size(); OrderIndex++)
+ {
+ uint32_t LocalChunkIndex = LocalContent.ChunkedContent.ChunkOrders[OrderIndex];
+ const IoHash LocalChunkHash = LocalContent.ChunkedContent.ChunkHashes[LocalChunkIndex];
+ uint64_t LocalChunkRawSize = LocalContent.ChunkedContent.ChunkRawSizes[LocalChunkIndex];
+
+ uint32_t VerifyChunkIndex = VerifyFolderContent.ChunkedContent.ChunkOrders[OrderIndex];
+ const IoHash VerifyChunkHash = VerifyFolderContent.ChunkedContent.ChunkHashes[VerifyChunkIndex];
+ uint64_t VerifyChunkRawSize = VerifyFolderContent.ChunkedContent.ChunkRawSizes[VerifyChunkIndex];
+
+ ZEN_ASSERT(LocalChunkHash == VerifyChunkHash);
+ ZEN_ASSERT(LocalChunkRawSize == VerifyChunkRawSize);
+ }
+ }
+
+ Result.PartManifest = PartManifestWriter.Save();
+ return Result;
+}
+
+void
+BuildsOperationUploadFolder::UploadAttachmentBatch(std::span<IoHash> RawHashes,
+ std::vector<IoHash>& OutUnknownChunks,
+ const ChunkedFolderContent& LocalContent,
+ const ChunkedContentLookup& LocalLookup,
+ const std::vector<std::vector<uint32_t>>& NewBlockChunks,
+ GeneratedBlocks& NewBlocks,
+ std::span<const uint32_t> LooseChunkIndexes,
+ UploadStatistics& UploadStats,
+ LooseChunksStatistics& LooseChunksStats)
+{
+ if (m_AbortFlag)
+ {
+ return;
+ }
+
+ UploadStatistics TempUploadStats;
+ LooseChunksStatistics TempLooseChunksStats;
+
+ Stopwatch TempUploadTimer;
+ auto __ = MakeGuard([&]() {
+ if (!m_Options.IsQuiet)
+ {
+ uint64_t TempChunkUploadTimeUs = TempUploadTimer.GetElapsedTimeUs();
+ ZEN_INFO(
+ "Uploaded {} ({}) blocks. "
+ "Compressed {} ({} {}B/s) and uploaded {} ({}) chunks. "
+ "Transferred {} ({}bits/s) in {}",
+ TempUploadStats.BlockCount.load(),
+ NiceBytes(TempUploadStats.BlocksBytes),
+
+ TempLooseChunksStats.CompressedChunkCount.load(),
+ NiceBytes(TempLooseChunksStats.CompressedChunkBytes.load()),
+ NiceNum(GetBytesPerSecond(TempLooseChunksStats.CompressChunksElapsedWallTimeUS, TempLooseChunksStats.ChunkByteCount)),
+ TempUploadStats.ChunkCount.load(),
+ NiceBytes(TempUploadStats.ChunksBytes),
+
+ NiceBytes(TempUploadStats.BlocksBytes + TempUploadStats.ChunksBytes),
+ NiceNum(GetBytesPerSecond(TempUploadStats.ElapsedWallTimeUS, TempUploadStats.ChunksBytes * 8)),
+ NiceTimeSpanMs(TempChunkUploadTimeUs / 1000));
+ }
+ });
+ UploadPartBlobs(LocalContent,
+ LocalLookup,
+ RawHashes,
+ NewBlockChunks,
+ NewBlocks,
+ LooseChunkIndexes,
+ m_LargeAttachmentSize,
+ TempUploadStats,
+ TempLooseChunksStats,
+ OutUnknownChunks);
+ UploadStats += TempUploadStats;
+ LooseChunksStats += TempLooseChunksStats;
+}
+
+void
+BuildsOperationUploadFolder::FinalizeBuildPartWithRetries(const UploadPart& Part,
+ const IoHash& PartHash,
+ std::vector<IoHash>& InOutUnknownChunks,
+ const ChunkedFolderContent& LocalContent,
+ const ChunkedContentLookup& LocalLookup,
+ const std::vector<std::vector<uint32_t>>& NewBlockChunks,
+ GeneratedBlocks& NewBlocks,
+ std::span<const uint32_t> LooseChunkIndexes,
+ UploadStatistics& UploadStats,
+ LooseChunksStatistics& LooseChunksStats)
+{
auto BuildUnkownChunksResponse = [](const std::vector<IoHash>& UnknownChunks, bool WillRetry) {
return fmt::format(
"The following build blobs was reported as needed for upload but was reported as existing at the start of the "
@@ -5803,9 +5993,9 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController
FormatArray<IoHash>(UnknownChunks, "\n "sv));
};
- if (!UnknownChunks.empty())
+ if (!InOutUnknownChunks.empty())
{
- ZEN_OPERATION_LOG_WARN(m_LogOutput, "{}", BuildUnkownChunksResponse(UnknownChunks, /*WillRetry*/ true));
+ ZEN_WARN("{}", BuildUnkownChunksResponse(InOutUnknownChunks, /*WillRetry*/ true));
}
uint32_t FinalizeBuildPartRetryCount = 5;
@@ -5815,10 +6005,9 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController
std::vector<IoHash> Needs = m_Storage.BuildStorage->FinalizeBuildPart(m_BuildId, Part.PartId, PartHash);
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "FinalizeBuildPart took {}. {} attachments are missing.",
- NiceTimeSpanMs(FinalizeBuildPartTimer.GetElapsedTimeMs()),
- Needs.size());
+ ZEN_INFO("FinalizeBuildPart took {}. {} attachments are missing.",
+ NiceTimeSpanMs(FinalizeBuildPartTimer.GetElapsedTimeMs()),
+ Needs.size());
}
if (Needs.empty())
{
@@ -5826,12 +6015,20 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController
}
if (m_Options.IsVerbose)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "FinalizeBuildPart needs attachments: {}", FormatArray<IoHash>(Needs, "\n "sv));
+ ZEN_INFO("FinalizeBuildPart needs attachments: {}", FormatArray<IoHash>(Needs, "\n "sv));
}
std::vector<IoHash> RetryUnknownChunks;
- UploadAttachments(Needs, RetryUnknownChunks);
- if (RetryUnknownChunks == UnknownChunks)
+ UploadAttachmentBatch(Needs,
+ RetryUnknownChunks,
+ LocalContent,
+ LocalLookup,
+ NewBlockChunks,
+ NewBlocks,
+ LooseChunkIndexes,
+ UploadStats,
+ LooseChunksStats);
+ if (RetryUnknownChunks == InOutUnknownChunks)
{
if (FinalizeBuildPartRetryCount > 0)
{
@@ -5841,100 +6038,68 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController
}
else
{
- UnknownChunks = RetryUnknownChunks;
- ZEN_OPERATION_LOG_WARN(m_LogOutput,
- "{}",
- BuildUnkownChunksResponse(UnknownChunks, /*WillRetry*/ FinalizeBuildPartRetryCount != 0));
+ InOutUnknownChunks = RetryUnknownChunks;
+ ZEN_WARN("{}", BuildUnkownChunksResponse(InOutUnknownChunks, /*WillRetry*/ FinalizeBuildPartRetryCount != 0));
}
}
- if (!UnknownChunks.empty())
+ if (!InOutUnknownChunks.empty())
{
- throw std::runtime_error(BuildUnkownChunksResponse(UnknownChunks, /*WillRetry*/ false));
+ throw std::runtime_error(BuildUnkownChunksResponse(InOutUnknownChunks, /*WillRetry*/ false));
}
+}
- if (!NewBlocks.BlockDescriptions.empty() && !m_AbortFlag)
- {
- uint64_t UploadBlockMetadataCount = 0;
- Stopwatch UploadBlockMetadataTimer;
+void
+BuildsOperationUploadFolder::UploadMissingBlockMetadata(GeneratedBlocks& NewBlocks, UploadStatistics& UploadStats)
+{
+ uint64_t UploadBlockMetadataCount = 0;
+ Stopwatch UploadBlockMetadataTimer;
- uint32_t FailedMetadataUploadCount = 1;
- int32_t MetadataUploadRetryCount = 3;
- while ((MetadataUploadRetryCount-- > 0) && (FailedMetadataUploadCount > 0))
+ uint32_t FailedMetadataUploadCount = 1;
+ int32_t MetadataUploadRetryCount = 3;
+ while ((MetadataUploadRetryCount-- > 0) && (FailedMetadataUploadCount > 0))
+ {
+ FailedMetadataUploadCount = 0;
+ for (size_t BlockIndex = 0; BlockIndex < NewBlocks.BlockDescriptions.size(); BlockIndex++)
{
- FailedMetadataUploadCount = 0;
- for (size_t BlockIndex = 0; BlockIndex < NewBlocks.BlockDescriptions.size(); BlockIndex++)
+ if (m_AbortFlag)
{
- if (m_AbortFlag)
+ break;
+ }
+ const IoHash& BlockHash = NewBlocks.BlockDescriptions[BlockIndex].BlockHash;
+ if (!NewBlocks.MetaDataHasBeenUploaded[BlockIndex])
+ {
+ const CbObject BlockMetaData =
+ BuildChunkBlockDescription(NewBlocks.BlockDescriptions[BlockIndex], NewBlocks.BlockMetaDatas[BlockIndex]);
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
{
- break;
+ m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId,
+ std::vector<IoHash>({BlockHash}),
+ std::vector<CbObject>({BlockMetaData}));
}
- const IoHash& BlockHash = NewBlocks.BlockDescriptions[BlockIndex].BlockHash;
- if (!NewBlocks.MetaDataHasBeenUploaded[BlockIndex])
+ bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData);
+ if (MetadataSucceeded)
{
- const CbObject BlockMetaData =
- BuildChunkBlockDescription(NewBlocks.BlockDescriptions[BlockIndex], NewBlocks.BlockMetaDatas[BlockIndex]);
- if (m_Storage.CacheStorage && m_Options.PopulateCache)
- {
- m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId,
- std::vector<IoHash>({BlockHash}),
- std::vector<CbObject>({BlockMetaData}));
- }
- bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData);
- if (MetadataSucceeded)
- {
- UploadStats.BlocksBytes += BlockMetaData.GetSize();
- NewBlocks.MetaDataHasBeenUploaded[BlockIndex] = true;
- UploadBlockMetadataCount++;
- }
- else
- {
- FailedMetadataUploadCount++;
- }
+ UploadStats.BlocksBytes += BlockMetaData.GetSize();
+ NewBlocks.MetaDataHasBeenUploaded[BlockIndex] = true;
+ UploadBlockMetadataCount++;
+ }
+ else
+ {
+ FailedMetadataUploadCount++;
}
}
}
- if (UploadBlockMetadataCount > 0)
+ }
+ if (UploadBlockMetadataCount > 0)
+ {
+ uint64_t ElapsedUS = UploadBlockMetadataTimer.GetElapsedTimeUs();
+ UploadStats.ElapsedWallTimeUS += ElapsedUS;
+ if (!m_Options.IsQuiet)
{
- uint64_t ElapsedUS = UploadBlockMetadataTimer.GetElapsedTimeUs();
- UploadStats.ElapsedWallTimeUS += ElapsedUS;
- if (!m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Uploaded metadata for {} blocks in {}",
- UploadBlockMetadataCount,
- NiceTimeSpanMs(ElapsedUS / 1000));
- }
+ ZEN_INFO("Uploaded metadata for {} blocks in {}", UploadBlockMetadataCount, NiceTimeSpanMs(ElapsedUS / 1000));
}
-
- // The newly generated blocks are now known blocks so the next part upload can use those blocks as well
- m_KnownBlocks.insert(m_KnownBlocks.end(), NewBlocks.BlockDescriptions.begin(), NewBlocks.BlockDescriptions.end());
}
-
- m_LogOutput.SetLogOperationProgress(PartStepOffset + (uint32_t)PartTaskSteps::PutBuildPartStats, StepCount);
-
- m_Storage.BuildStorage->PutBuildPartStats(
- m_BuildId,
- Part.PartId,
- {{"totalSize", double(Part.LocalFolderScanStats.FoundFileByteCount.load())},
- {"reusedRatio", AcceptedByteCountPercent / 100.0},
- {"reusedBlockCount", double(FindBlocksStats.AcceptedBlockCount)},
- {"reusedBlockByteCount", double(ReuseBlocksStats.AcceptedRawByteCount)},
- {"newBlockCount", double(FindBlocksStats.NewBlocksCount)},
- {"newBlockByteCount", double(FindBlocksStats.NewBlocksChunkByteCount)},
- {"uploadedCount", double(UploadStats.BlockCount.load() + UploadStats.ChunkCount.load())},
- {"uploadedByteCount", double(UploadStats.BlocksBytes.load() + UploadStats.ChunksBytes.load())},
- {"uploadedBytesPerSec",
- double(GetBytesPerSecond(UploadStats.ElapsedWallTimeUS, UploadStats.ChunksBytes + UploadStats.BlocksBytes))},
- {"elapsedTimeSec", double(UploadTimer.GetElapsedTimeMs() / 1000.0)}});
-
- m_LocalFolderScanStats += Part.LocalFolderScanStats;
- m_ChunkingStats += ChunkingStats;
- m_FindBlocksStats += FindBlocksStats;
- m_ReuseBlocksStats += ReuseBlocksStats;
- m_UploadStats += UploadStats;
- m_GenerateBlocksStats += GenerateBlocksStats;
- m_LooseChunksStats += LooseChunksStats;
}
void
@@ -5950,463 +6115,490 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co
std::vector<IoHash>& OutUnknownChunks)
{
ZEN_TRACE_CPU("UploadPartBlobs");
+
+ UploadPartClassification Classification =
+ ClassifyUploadRawHashes(RawHashes, Content, Lookup, NewBlocks, LooseChunkIndexes, OutUnknownChunks);
+
+ if (Classification.BlockIndexes.empty() && Classification.LooseChunkOrderIndexes.empty())
{
- std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Upload Blobs"));
- OperationLogOutput::ProgressBar& Progress(*ProgressBarPtr);
+ return;
+ }
- WorkerThreadPool& ReadChunkPool = m_IOWorkerPool;
- WorkerThreadPool& UploadChunkPool = m_NetworkPool;
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Upload Blobs");
+
+ FilteredRate FilteredGenerateBlockBytesPerSecond;
+ FilteredRate FilteredCompressedBytesPerSecond;
+ FilteredRate FilteredUploadedBytesPerSecond;
+
+ ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+
+ std::atomic<size_t> UploadedBlockSize = 0;
+ std::atomic<size_t> UploadedBlockCount = 0;
+ std::atomic<size_t> UploadedRawChunkSize = 0;
+ std::atomic<size_t> UploadedCompressedChunkSize = 0;
+ std::atomic<uint32_t> UploadedChunkCount = 0;
+ std::atomic<uint64_t> GeneratedBlockCount = 0;
+ std::atomic<uint64_t> GeneratedBlockByteCount = 0;
+ std::atomic<uint64_t> QueuedPendingInMemoryBlocksForUpload = 0;
+
+ const size_t UploadBlockCount = Classification.BlockIndexes.size();
+ const uint32_t UploadChunkCount = gsl::narrow<uint32_t>(Classification.LooseChunkOrderIndexes.size());
+ const uint64_t TotalRawSize = Classification.TotalLooseChunksSize + Classification.TotalBlocksSize;
+
+ UploadPartBlobsContext Context{.Work = Work,
+ .ReadChunkPool = m_IOWorkerPool,
+ .UploadChunkPool = m_NetworkPool,
+ .FilteredGenerateBlockBytesPerSecond = FilteredGenerateBlockBytesPerSecond,
+ .FilteredCompressedBytesPerSecond = FilteredCompressedBytesPerSecond,
+ .FilteredUploadedBytesPerSecond = FilteredUploadedBytesPerSecond,
+ .UploadedBlockSize = UploadedBlockSize,
+ .UploadedBlockCount = UploadedBlockCount,
+ .UploadedRawChunkSize = UploadedRawChunkSize,
+ .UploadedCompressedChunkSize = UploadedCompressedChunkSize,
+ .UploadedChunkCount = UploadedChunkCount,
+ .GeneratedBlockCount = GeneratedBlockCount,
+ .GeneratedBlockByteCount = GeneratedBlockByteCount,
+ .QueuedPendingInMemoryBlocksForUpload = QueuedPendingInMemoryBlocksForUpload,
+ .UploadBlockCount = UploadBlockCount,
+ .UploadChunkCount = UploadChunkCount,
+ .LargeAttachmentSize = LargeAttachmentSize,
+ .NewBlocks = NewBlocks,
+ .Content = Content,
+ .Lookup = Lookup,
+ .NewBlockChunks = NewBlockChunks,
+ .LooseChunkIndexes = LooseChunkIndexes,
+ .TempUploadStats = TempUploadStats,
+ .TempLooseChunksStats = TempLooseChunksStats};
+
+ ScheduleBlockGenerationAndUpload(Context, Classification.BlockIndexes);
+ ScheduleLooseChunkCompressionAndUpload(Context, Classification.LooseChunkOrderIndexes);
+
+ Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
+ ZEN_UNUSED(PendingWork);
+ FilteredCompressedBytesPerSecond.Update(TempLooseChunksStats.CompressedChunkRawBytes.load());
+ FilteredGenerateBlockBytesPerSecond.Update(GeneratedBlockByteCount.load());
+ FilteredUploadedBytesPerSecond.Update(UploadedCompressedChunkSize.load() + UploadedBlockSize.load());
+ uint64_t UploadedRawSize = UploadedRawChunkSize.load() + UploadedBlockSize.load();
+ uint64_t UploadedCompressedSize = UploadedCompressedChunkSize.load() + UploadedBlockSize.load();
+
+ std::string Details = fmt::format(
+ "Compressed {}/{} ({}/{}{}) chunks. "
+ "Uploaded {}/{} ({}/{}) blobs "
+ "({}{})",
+ TempLooseChunksStats.CompressedChunkCount.load(),
+ Classification.LooseChunkOrderIndexes.size(),
+ NiceBytes(TempLooseChunksStats.CompressedChunkRawBytes),
+ NiceBytes(Classification.TotalLooseChunksSize),
+ (TempLooseChunksStats.CompressedChunkCount == Classification.LooseChunkOrderIndexes.size())
+ ? ""
+ : fmt::format(" {}B/s", NiceNum(FilteredCompressedBytesPerSecond.GetCurrent())),
+
+ UploadedBlockCount.load() + UploadedChunkCount.load(),
+ UploadBlockCount + UploadChunkCount,
+ NiceBytes(UploadedRawSize),
+ NiceBytes(TotalRawSize),
+
+ NiceBytes(UploadedCompressedSize),
+ (UploadedBlockCount == UploadBlockCount && UploadedChunkCount == UploadChunkCount)
+ ? ""
+ : fmt::format(" {}bits/s", NiceNum(FilteredUploadedBytesPerSecond.GetCurrent())));
+
+ ProgressBar->UpdateState({.Task = "Uploading blobs ",
+ .Details = Details,
+ .TotalCount = gsl::narrow<uint64_t>(TotalRawSize),
+ .RemainingCount = gsl::narrow<uint64_t>(TotalRawSize - UploadedRawSize),
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ false);
+ });
- FilteredRate FilteredGenerateBlockBytesPerSecond;
- FilteredRate FilteredCompressedBytesPerSecond;
- FilteredRate FilteredUploadedBytesPerSecond;
+ ZEN_ASSERT(m_AbortFlag || QueuedPendingInMemoryBlocksForUpload.load() == 0);
- ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+ ProgressBar->Finish();
- std::atomic<size_t> UploadedBlockSize = 0;
- std::atomic<size_t> UploadedBlockCount = 0;
- std::atomic<size_t> UploadedRawChunkSize = 0;
- std::atomic<size_t> UploadedCompressedChunkSize = 0;
- std::atomic<uint32_t> UploadedChunkCount = 0;
+ TempUploadStats.ElapsedWallTimeUS += FilteredUploadedBytesPerSecond.GetElapsedTimeUS();
+ TempLooseChunksStats.CompressChunksElapsedWallTimeUS += FilteredCompressedBytesPerSecond.GetElapsedTimeUS();
+}
- tsl::robin_map<uint32_t, uint32_t> ChunkIndexToLooseChunkOrderIndex;
- ChunkIndexToLooseChunkOrderIndex.reserve(LooseChunkIndexes.size());
- for (uint32_t OrderIndex = 0; OrderIndex < LooseChunkIndexes.size(); OrderIndex++)
- {
- ChunkIndexToLooseChunkOrderIndex.insert_or_assign(LooseChunkIndexes[OrderIndex], OrderIndex);
- }
+BuildsOperationUploadFolder::UploadPartClassification
+BuildsOperationUploadFolder::ClassifyUploadRawHashes(std::span<IoHash> RawHashes,
+ const ChunkedFolderContent& Content,
+ const ChunkedContentLookup& Lookup,
+ const GeneratedBlocks& NewBlocks,
+ std::span<const uint32_t> LooseChunkIndexes,
+ std::vector<IoHash>& OutUnknownChunks)
+{
+ UploadPartClassification Result;
- std::vector<size_t> BlockIndexes;
- std::vector<uint32_t> LooseChunkOrderIndexes;
+ tsl::robin_map<uint32_t, uint32_t> ChunkIndexToLooseChunkOrderIndex;
+ ChunkIndexToLooseChunkOrderIndex.reserve(LooseChunkIndexes.size());
+ for (uint32_t OrderIndex = 0; OrderIndex < LooseChunkIndexes.size(); OrderIndex++)
+ {
+ ChunkIndexToLooseChunkOrderIndex.insert_or_assign(LooseChunkIndexes[OrderIndex], OrderIndex);
+ }
- uint64_t TotalLooseChunksSize = 0;
- uint64_t TotalBlocksSize = 0;
- for (const IoHash& RawHash : RawHashes)
+ for (const IoHash& RawHash : RawHashes)
+ {
+ if (auto It = NewBlocks.BlockHashToBlockIndex.find(RawHash); It != NewBlocks.BlockHashToBlockIndex.end())
{
- if (auto It = NewBlocks.BlockHashToBlockIndex.find(RawHash); It != NewBlocks.BlockHashToBlockIndex.end())
- {
- BlockIndexes.push_back(It->second);
- TotalBlocksSize += NewBlocks.BlockSizes[It->second];
- }
- else if (auto ChunkIndexIt = Lookup.ChunkHashToChunkIndex.find(RawHash); ChunkIndexIt != Lookup.ChunkHashToChunkIndex.end())
- {
- const uint32_t ChunkIndex = ChunkIndexIt->second;
- if (auto LooseOrderIndexIt = ChunkIndexToLooseChunkOrderIndex.find(ChunkIndex);
- LooseOrderIndexIt != ChunkIndexToLooseChunkOrderIndex.end())
- {
- LooseChunkOrderIndexes.push_back(LooseOrderIndexIt->second);
- TotalLooseChunksSize += Content.ChunkedContent.ChunkRawSizes[ChunkIndex];
- }
- }
- else
+ Result.BlockIndexes.push_back(It->second);
+ Result.TotalBlocksSize += NewBlocks.BlockSizes[It->second];
+ }
+ else if (auto ChunkIndexIt = Lookup.ChunkHashToChunkIndex.find(RawHash); ChunkIndexIt != Lookup.ChunkHashToChunkIndex.end())
+ {
+ const uint32_t ChunkIndex = ChunkIndexIt->second;
+ if (auto LooseOrderIndexIt = ChunkIndexToLooseChunkOrderIndex.find(ChunkIndex);
+ LooseOrderIndexIt != ChunkIndexToLooseChunkOrderIndex.end())
{
- OutUnknownChunks.push_back(RawHash);
+ Result.LooseChunkOrderIndexes.push_back(LooseOrderIndexIt->second);
+ Result.TotalLooseChunksSize += Content.ChunkedContent.ChunkRawSizes[ChunkIndex];
}
}
- if (BlockIndexes.empty() && LooseChunkOrderIndexes.empty())
+ else
{
- return;
+ OutUnknownChunks.push_back(RawHash);
}
+ }
+ return Result;
+}
- uint64_t TotalRawSize = TotalLooseChunksSize + TotalBlocksSize;
-
- const size_t UploadBlockCount = BlockIndexes.size();
- const uint32_t UploadChunkCount = gsl::narrow<uint32_t>(LooseChunkOrderIndexes.size());
-
- auto AsyncUploadBlock = [this,
- &Work,
- &NewBlocks,
- UploadBlockCount,
- &UploadedBlockCount,
- UploadChunkCount,
- &UploadedChunkCount,
- &UploadedBlockSize,
- &TempUploadStats,
- &FilteredUploadedBytesPerSecond,
- &UploadChunkPool](const size_t BlockIndex,
- const IoHash BlockHash,
- CompositeBuffer&& Payload,
- std::atomic<uint64_t>& QueuedPendingInMemoryBlocksForUpload) {
- bool IsInMemoryBlock = true;
- if (QueuedPendingInMemoryBlocksForUpload.load() > 16)
- {
- ZEN_TRACE_CPU("AsyncUploadBlock_WriteTempBlock");
- std::filesystem::path TempFilePath = m_Options.TempDir / (BlockHash.ToHexString());
- Payload = CompositeBuffer(WriteToTempFile(std::move(Payload), TempFilePath));
- IsInMemoryBlock = false;
- }
- else
- {
- QueuedPendingInMemoryBlocksForUpload++;
- }
+void
+BuildsOperationUploadFolder::ScheduleBlockGenerationAndUpload(UploadPartBlobsContext& Context, std::span<const size_t> BlockIndexes)
+{
+ for (const size_t BlockIndex : BlockIndexes)
+ {
+ const IoHash& BlockHash = Context.NewBlocks.BlockDescriptions[BlockIndex].BlockHash;
+ if (m_AbortFlag)
+ {
+ break;
+ }
+ Context.Work.ScheduleWork(
+ Context.ReadChunkPool,
+ [this, &Context, BlockHash = IoHash(BlockHash), BlockIndex, GenerateBlockCount = BlockIndexes.size()](std::atomic<bool>&) {
+ if (m_AbortFlag)
+ {
+ return;
+ }
+ ZEN_TRACE_CPU("UploadPartBlobs_GenerateBlock");
- Work.ScheduleWork(
- UploadChunkPool,
- [this,
- &QueuedPendingInMemoryBlocksForUpload,
- &NewBlocks,
- UploadBlockCount,
- &UploadedBlockCount,
- UploadChunkCount,
- &UploadedChunkCount,
- &UploadedBlockSize,
- &TempUploadStats,
- &FilteredUploadedBytesPerSecond,
- IsInMemoryBlock,
- BlockIndex,
- BlockHash,
- Payload = CompositeBuffer(std::move(Payload))](std::atomic<bool>&) mutable {
- auto _ = MakeGuard([IsInMemoryBlock, &QueuedPendingInMemoryBlocksForUpload] {
- if (IsInMemoryBlock)
- {
- QueuedPendingInMemoryBlocksForUpload--;
- }
- });
- if (!m_AbortFlag)
+ Context.FilteredGenerateBlockBytesPerSecond.Start();
+
+ Stopwatch GenerateTimer;
+ CompositeBuffer Payload;
+ if (Context.NewBlocks.BlockHeaders[BlockIndex])
+ {
+ Payload = RebuildBlock(Context.Content,
+ Context.Lookup,
+ std::move(Context.NewBlocks.BlockHeaders[BlockIndex]),
+ Context.NewBlockChunks[BlockIndex])
+ .GetCompressed();
+ }
+ else
+ {
+ ChunkBlockDescription BlockDescription;
+ CompressedBuffer CompressedBlock =
+ GenerateBlock(Context.Content, Context.Lookup, Context.NewBlockChunks[BlockIndex], BlockDescription);
+ if (!CompressedBlock)
{
- ZEN_TRACE_CPU("AsyncUploadBlock");
+ throw std::runtime_error(fmt::format("Failed generating block {}", BlockHash));
+ }
+ ZEN_ASSERT(BlockDescription.BlockHash == BlockHash);
+ Payload = std::move(CompressedBlock).GetCompressed();
+ }
- const uint64_t PayloadSize = Payload.GetSize();
+ Context.GeneratedBlockByteCount += Context.NewBlocks.BlockSizes[BlockIndex];
+ if (Context.GeneratedBlockCount.fetch_add(1) + 1 == GenerateBlockCount)
+ {
+ Context.FilteredGenerateBlockBytesPerSecond.Stop();
+ }
+ if (m_Options.IsVerbose)
+ {
+ ZEN_INFO("{} block {} ({}) containing {} chunks in {}",
+ Context.NewBlocks.BlockHeaders[BlockIndex] ? "Regenerated" : "Generated",
+ Context.NewBlocks.BlockDescriptions[BlockIndex].BlockHash,
+ NiceBytes(Context.NewBlocks.BlockSizes[BlockIndex]),
+ Context.NewBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size(),
+ NiceTimeSpanMs(GenerateTimer.GetElapsedTimeMs()));
+ }
+ if (!m_AbortFlag)
+ {
+ UploadBlockPayload(Context, BlockIndex, BlockHash, std::move(Payload));
+ }
+ });
+ }
+}
- FilteredUploadedBytesPerSecond.Start();
- const CbObject BlockMetaData =
- BuildChunkBlockDescription(NewBlocks.BlockDescriptions[BlockIndex], NewBlocks.BlockMetaDatas[BlockIndex]);
+void
+BuildsOperationUploadFolder::UploadBlockPayload(UploadPartBlobsContext& Context,
+ size_t BlockIndex,
+ const IoHash& BlockHash,
+ CompositeBuffer Payload)
+{
+ bool IsInMemoryBlock = true;
+ if (Context.QueuedPendingInMemoryBlocksForUpload.load() > 16)
+ {
+ ZEN_TRACE_CPU("AsyncUploadBlock_WriteTempBlock");
+ std::filesystem::path TempFilePath = m_Options.TempDir / (BlockHash.ToHexString());
+ Payload = CompositeBuffer(WriteToTempFile(std::move(Payload), TempFilePath));
+ IsInMemoryBlock = false;
+ }
+ else
+ {
+ Context.QueuedPendingInMemoryBlocksForUpload++;
+ }
- if (m_Storage.CacheStorage && m_Options.PopulateCache)
- {
- m_Storage.CacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload);
- }
- m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload);
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Uploaded block {} ({}) containing {} chunks",
- BlockHash,
- NiceBytes(PayloadSize),
- NewBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size());
- }
- UploadedBlockSize += PayloadSize;
- TempUploadStats.BlocksBytes += PayloadSize;
+ Context.Work.ScheduleWork(
+ Context.UploadChunkPool,
+ [this, &Context, IsInMemoryBlock, BlockIndex, BlockHash = IoHash(BlockHash), Payload = CompositeBuffer(std::move(Payload))](
+ std::atomic<bool>&) {
+ auto _ = MakeGuard([IsInMemoryBlock, &Context] {
+ if (IsInMemoryBlock)
+ {
+ Context.QueuedPendingInMemoryBlocksForUpload--;
+ }
+ });
+ if (m_AbortFlag)
+ {
+ return;
+ }
+ ZEN_TRACE_CPU("AsyncUploadBlock");
- if (m_Storage.CacheStorage && m_Options.PopulateCache)
- {
- m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId,
- std::vector<IoHash>({BlockHash}),
- std::vector<CbObject>({BlockMetaData}));
- }
- bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData);
- if (MetadataSucceeded)
- {
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Uploaded block {} metadata ({})",
- BlockHash,
- NiceBytes(BlockMetaData.GetSize()));
- }
+ const uint64_t PayloadSize = Payload.GetSize();
- NewBlocks.MetaDataHasBeenUploaded[BlockIndex] = true;
- TempUploadStats.BlocksBytes += BlockMetaData.GetSize();
- }
+ Context.FilteredUploadedBytesPerSecond.Start();
+ const CbObject BlockMetaData =
+ BuildChunkBlockDescription(Context.NewBlocks.BlockDescriptions[BlockIndex], Context.NewBlocks.BlockMetaDatas[BlockIndex]);
- TempUploadStats.BlockCount++;
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
+ {
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload);
+ }
- UploadedBlockCount++;
- if (UploadedBlockCount == UploadBlockCount && UploadedChunkCount == UploadChunkCount)
- {
- FilteredUploadedBytesPerSecond.Stop();
- }
- }
- });
- };
+ try
+ {
+ m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload);
+ }
+ catch (const std::exception&)
+ {
+ // Silence http errors due to abort
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
+ }
- auto AsyncUploadLooseChunk = [this,
- LargeAttachmentSize,
- &Work,
- &UploadChunkPool,
- &FilteredUploadedBytesPerSecond,
- &UploadedBlockCount,
- &UploadedChunkCount,
- UploadBlockCount,
- UploadChunkCount,
- &UploadedCompressedChunkSize,
- &UploadedRawChunkSize,
- &TempUploadStats](const IoHash& RawHash, uint64_t RawSize, CompositeBuffer&& Payload) {
- Work.ScheduleWork(
- UploadChunkPool,
- [this,
- &Work,
- LargeAttachmentSize,
- &FilteredUploadedBytesPerSecond,
- &UploadChunkPool,
- &UploadedBlockCount,
- &UploadedChunkCount,
- UploadBlockCount,
- UploadChunkCount,
- &UploadedCompressedChunkSize,
- &UploadedRawChunkSize,
- &TempUploadStats,
- RawHash,
- RawSize,
- Payload = CompositeBuffer(std::move(Payload))](std::atomic<bool>&) mutable {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("AsyncUploadLooseChunk");
+ if (m_AbortFlag)
+ {
+ return;
+ }
+ if (m_Options.IsVerbose)
+ {
+ ZEN_INFO("Uploaded block {} ({}) containing {} chunks",
+ BlockHash,
+ NiceBytes(PayloadSize),
+ Context.NewBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size());
+ }
+ Context.UploadedBlockSize += PayloadSize;
+ Context.TempUploadStats.BlocksBytes += PayloadSize;
- const uint64_t PayloadSize = Payload.GetSize();
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
+ {
+ m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId,
+ std::vector<IoHash>({BlockHash}),
+ std::vector<CbObject>({BlockMetaData}));
+ }
- if (m_Storage.CacheStorage && m_Options.PopulateCache)
- {
- m_Storage.CacheStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload);
- }
+ bool MetadataSucceeded = false;
+ try
+ {
+ MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData);
+ }
+ catch (const std::exception&)
+ {
+ // Silence http errors due to abort
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
+ }
+ if (m_AbortFlag)
+ {
+ return;
+ }
+ if (MetadataSucceeded)
+ {
+ if (m_Options.IsVerbose)
+ {
+ ZEN_INFO("Uploaded block {} metadata ({})", BlockHash, NiceBytes(BlockMetaData.GetSize()));
+ }
+ Context.NewBlocks.MetaDataHasBeenUploaded[BlockIndex] = true;
+ Context.TempUploadStats.BlocksBytes += BlockMetaData.GetSize();
+ }
- if (PayloadSize >= LargeAttachmentSize)
- {
- ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart");
- TempUploadStats.MultipartAttachmentCount++;
- std::vector<std::function<void()>> MultipartWork = m_Storage.BuildStorage->PutLargeBuildBlob(
- m_BuildId,
- RawHash,
- ZenContentType::kCompressedBinary,
- PayloadSize,
- [Payload = std::move(Payload), &FilteredUploadedBytesPerSecond](uint64_t Offset,
- uint64_t Size) mutable -> IoBuffer {
- FilteredUploadedBytesPerSecond.Start();
-
- IoBuffer PartPayload = Payload.Mid(Offset, Size).Flatten().AsIoBuffer();
- PartPayload.SetContentType(ZenContentType::kBinary);
- return PartPayload;
- },
- [RawSize,
- &TempUploadStats,
- &UploadedCompressedChunkSize,
- &UploadChunkPool,
- &UploadedBlockCount,
- UploadBlockCount,
- &UploadedChunkCount,
- UploadChunkCount,
- &FilteredUploadedBytesPerSecond,
- &UploadedRawChunkSize](uint64_t SentBytes, bool IsComplete) {
- TempUploadStats.ChunksBytes += SentBytes;
- UploadedCompressedChunkSize += SentBytes;
- if (IsComplete)
- {
- TempUploadStats.ChunkCount++;
- UploadedChunkCount++;
- if (UploadedBlockCount == UploadBlockCount && UploadedChunkCount == UploadChunkCount)
- {
- FilteredUploadedBytesPerSecond.Stop();
- }
- UploadedRawChunkSize += RawSize;
- }
- });
- for (auto& WorkPart : MultipartWork)
- {
- Work.ScheduleWork(UploadChunkPool, [Work = std::move(WorkPart)](std::atomic<bool>& AbortFlag) {
- ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart_Work");
- if (!AbortFlag)
- {
- Work();
- }
- });
- }
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Uploaded multipart chunk {} ({})", RawHash, NiceBytes(PayloadSize));
- }
- }
- else
- {
- ZEN_TRACE_CPU("AsyncUploadLooseChunk_Singlepart");
- m_Storage.BuildStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload);
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Uploaded chunk {} ({})", RawHash, NiceBytes(PayloadSize));
- }
- TempUploadStats.ChunksBytes += Payload.GetSize();
- TempUploadStats.ChunkCount++;
- UploadedCompressedChunkSize += Payload.GetSize();
- UploadedRawChunkSize += RawSize;
- UploadedChunkCount++;
- if (UploadedChunkCount == UploadChunkCount)
- {
- FilteredUploadedBytesPerSecond.Stop();
- }
- }
- }
- });
- };
+ Context.TempUploadStats.BlockCount++;
- std::vector<size_t> GenerateBlockIndexes;
+ if (Context.UploadedBlockCount.fetch_add(1) + 1 == Context.UploadBlockCount &&
+ Context.UploadedChunkCount == Context.UploadChunkCount)
+ {
+ Context.FilteredUploadedBytesPerSecond.Stop();
+ }
+ });
+}
- std::atomic<uint64_t> GeneratedBlockCount = 0;
- std::atomic<uint64_t> GeneratedBlockByteCount = 0;
+void
+BuildsOperationUploadFolder::ScheduleLooseChunkCompressionAndUpload(UploadPartBlobsContext& Context,
+ std::span<const uint32_t> LooseChunkOrderIndexes)
+{
+ for (const uint32_t LooseChunkOrderIndex : LooseChunkOrderIndexes)
+ {
+ const uint32_t ChunkIndex = Context.LooseChunkIndexes[LooseChunkOrderIndex];
+ Context.Work.ScheduleWork(Context.ReadChunkPool,
+ [this, &Context, LooseChunkOrderCount = LooseChunkOrderIndexes.size(), ChunkIndex](std::atomic<bool>&) {
+ if (m_AbortFlag)
+ {
+ return;
+ }
+ ZEN_TRACE_CPU("UploadPartBlobs_CompressChunk");
- std::atomic<uint64_t> QueuedPendingInMemoryBlocksForUpload = 0;
+ Context.FilteredCompressedBytesPerSecond.Start();
+ Stopwatch CompressTimer;
+ CompositeBuffer Payload =
+ CompressChunk(Context.Content, Context.Lookup, ChunkIndex, Context.TempLooseChunksStats);
+ if (m_Options.IsVerbose)
+ {
+ ZEN_INFO("Compressed chunk {} ({} -> {}) in {}",
+ Context.Content.ChunkedContent.ChunkHashes[ChunkIndex],
+ NiceBytes(Context.Content.ChunkedContent.ChunkRawSizes[ChunkIndex]),
+ NiceBytes(Payload.GetSize()),
+ NiceTimeSpanMs(CompressTimer.GetElapsedTimeMs()));
+ }
+ const uint64_t ChunkRawSize = Context.Content.ChunkedContent.ChunkRawSizes[ChunkIndex];
+ Context.TempUploadStats.ReadFromDiskBytes += ChunkRawSize;
+ if (Context.TempLooseChunksStats.CompressedChunkCount == LooseChunkOrderCount)
+ {
+ Context.FilteredCompressedBytesPerSecond.Stop();
+ }
+ if (!m_AbortFlag)
+ {
+ UploadLooseChunkPayload(Context,
+ Context.Content.ChunkedContent.ChunkHashes[ChunkIndex],
+ ChunkRawSize,
+ std::move(Payload));
+ }
+ });
+ }
+}
- // Start generation of any non-prebuilt blocks and schedule upload
- for (const size_t BlockIndex : BlockIndexes)
- {
- const IoHash& BlockHash = NewBlocks.BlockDescriptions[BlockIndex].BlockHash;
- if (!m_AbortFlag)
+void
+BuildsOperationUploadFolder::UploadLooseChunkPayload(UploadPartBlobsContext& Context,
+ const IoHash& RawHash,
+ uint64_t RawSize,
+ CompositeBuffer Payload)
+{
+ Context.Work.ScheduleWork(
+ Context.UploadChunkPool,
+ [this, &Context, RawHash = IoHash(RawHash), RawSize, Payload = CompositeBuffer(std::move(Payload))](std::atomic<bool>&) mutable {
+ if (m_AbortFlag)
{
- Work.ScheduleWork(
- ReadChunkPool,
- [this,
- BlockHash = IoHash(BlockHash),
- BlockIndex,
- &FilteredGenerateBlockBytesPerSecond,
- &Content,
- &Lookup,
- &NewBlocks,
- &NewBlockChunks,
- &GenerateBlockIndexes,
- &GeneratedBlockCount,
- &GeneratedBlockByteCount,
- &AsyncUploadBlock,
- &QueuedPendingInMemoryBlocksForUpload](std::atomic<bool>&) {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("UploadPartBlobs_GenerateBlock");
+ return;
+ }
+ ZEN_TRACE_CPU("AsyncUploadLooseChunk");
- FilteredGenerateBlockBytesPerSecond.Start();
+ const uint64_t PayloadSize = Payload.GetSize();
- Stopwatch GenerateTimer;
- CompositeBuffer Payload;
- if (NewBlocks.BlockHeaders[BlockIndex])
- {
- Payload =
- RebuildBlock(Content, Lookup, std::move(NewBlocks.BlockHeaders[BlockIndex]), NewBlockChunks[BlockIndex])
- .GetCompressed();
- }
- else
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
+ {
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload);
+ }
+
+ if (PayloadSize >= Context.LargeAttachmentSize)
+ {
+ ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart");
+ Context.TempUploadStats.MultipartAttachmentCount++;
+ try
+ {
+ std::vector<std::function<void()>> MultipartWork = m_Storage.BuildStorage->PutLargeBuildBlob(
+ m_BuildId,
+ RawHash,
+ ZenContentType::kCompressedBinary,
+ PayloadSize,
+ [Payload = std::move(Payload), &Context](uint64_t Offset, uint64_t Size) -> IoBuffer {
+ Context.FilteredUploadedBytesPerSecond.Start();
+
+ IoBuffer PartPayload = Payload.Mid(Offset, Size).Flatten().AsIoBuffer();
+ PartPayload.SetContentType(ZenContentType::kBinary);
+ return PartPayload;
+ },
+ [&Context, RawSize](uint64_t SentBytes, bool IsComplete) {
+ Context.TempUploadStats.ChunksBytes += SentBytes;
+ Context.UploadedCompressedChunkSize += SentBytes;
+ if (IsComplete)
{
- ChunkBlockDescription BlockDescription;
- CompressedBuffer CompressedBlock =
- GenerateBlock(Content, Lookup, NewBlockChunks[BlockIndex], BlockDescription);
- if (!CompressedBlock)
+ Context.TempUploadStats.ChunkCount++;
+ if (Context.UploadedChunkCount.fetch_add(1) + 1 == Context.UploadChunkCount &&
+ Context.UploadedBlockCount == Context.UploadBlockCount)
{
- throw std::runtime_error(fmt::format("Failed generating block {}", BlockHash));
+ Context.FilteredUploadedBytesPerSecond.Stop();
}
- ZEN_ASSERT(BlockDescription.BlockHash == BlockHash);
- Payload = std::move(CompressedBlock).GetCompressed();
+ Context.UploadedRawChunkSize += RawSize;
}
-
- GeneratedBlockByteCount += NewBlocks.BlockSizes[BlockIndex];
- GeneratedBlockCount++;
- if (GeneratedBlockCount == GenerateBlockIndexes.size())
- {
- FilteredGenerateBlockBytesPerSecond.Stop();
- }
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "{} block {} ({}) containing {} chunks in {}",
- NewBlocks.BlockHeaders[BlockIndex] ? "Regenerated" : "Generated",
- NewBlocks.BlockDescriptions[BlockIndex].BlockHash,
- NiceBytes(NewBlocks.BlockSizes[BlockIndex]),
- NewBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size(),
- NiceTimeSpanMs(GenerateTimer.GetElapsedTimeMs()));
- }
- if (!m_AbortFlag)
+ });
+ for (auto& WorkPart : MultipartWork)
+ {
+ Context.Work.ScheduleWork(Context.UploadChunkPool, [Work = std::move(WorkPart)](std::atomic<bool>& AbortFlag) {
+ ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart_Work");
+ if (!AbortFlag)
{
- AsyncUploadBlock(BlockIndex, BlockHash, std::move(Payload), QueuedPendingInMemoryBlocksForUpload);
+ Work();
}
- }
- });
- }
- }
-
- // Start compression of any non-precompressed loose chunks and schedule upload
- for (const uint32_t LooseChunkOrderIndex : LooseChunkOrderIndexes)
- {
- const uint32_t ChunkIndex = LooseChunkIndexes[LooseChunkOrderIndex];
- Work.ScheduleWork(
- ReadChunkPool,
- [this,
- &Content,
- &Lookup,
- &TempLooseChunksStats,
- &LooseChunkOrderIndexes,
- &FilteredCompressedBytesPerSecond,
- &TempUploadStats,
- &AsyncUploadLooseChunk,
- ChunkIndex](std::atomic<bool>&) {
+ });
+ }
+ if (m_Options.IsVerbose)
+ {
+ ZEN_INFO("Uploaded multipart chunk {} ({})", RawHash, NiceBytes(PayloadSize));
+ }
+ }
+ catch (const std::exception&)
+ {
+ // Silence http errors due to abort
if (!m_AbortFlag)
{
- ZEN_TRACE_CPU("UploadPartBlobs_CompressChunk");
-
- FilteredCompressedBytesPerSecond.Start();
- Stopwatch CompressTimer;
- CompositeBuffer Payload = CompressChunk(Content, Lookup, ChunkIndex, TempLooseChunksStats);
- if (m_Options.IsVerbose)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Compressed chunk {} ({} -> {}) in {}",
- Content.ChunkedContent.ChunkHashes[ChunkIndex],
- NiceBytes(Content.ChunkedContent.ChunkRawSizes[ChunkIndex]),
- NiceBytes(Payload.GetSize()),
- NiceTimeSpanMs(CompressTimer.GetElapsedTimeMs()));
- }
- const uint64_t ChunkRawSize = Content.ChunkedContent.ChunkRawSizes[ChunkIndex];
- TempUploadStats.ReadFromDiskBytes += ChunkRawSize;
- if (TempLooseChunksStats.CompressedChunkCount == LooseChunkOrderIndexes.size())
- {
- FilteredCompressedBytesPerSecond.Stop();
- }
- if (!m_AbortFlag)
- {
- AsyncUploadLooseChunk(Content.ChunkedContent.ChunkHashes[ChunkIndex], ChunkRawSize, std::move(Payload));
- }
+ throw;
}
- });
- }
+ }
+ return;
+ }
- Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
- ZEN_UNUSED(PendingWork);
- FilteredCompressedBytesPerSecond.Update(TempLooseChunksStats.CompressedChunkRawBytes.load());
- FilteredGenerateBlockBytesPerSecond.Update(GeneratedBlockByteCount.load());
- FilteredUploadedBytesPerSecond.Update(UploadedCompressedChunkSize.load() + UploadedBlockSize.load());
- uint64_t UploadedRawSize = UploadedRawChunkSize.load() + UploadedBlockSize.load();
- uint64_t UploadedCompressedSize = UploadedCompressedChunkSize.load() + UploadedBlockSize.load();
-
- std::string Details = fmt::format(
- "Compressed {}/{} ({}/{}{}) chunks. "
- "Uploaded {}/{} ({}/{}) blobs "
- "({}{})",
- TempLooseChunksStats.CompressedChunkCount.load(),
- LooseChunkOrderIndexes.size(),
- NiceBytes(TempLooseChunksStats.CompressedChunkRawBytes),
- NiceBytes(TotalLooseChunksSize),
- (TempLooseChunksStats.CompressedChunkCount == LooseChunkOrderIndexes.size())
- ? ""
- : fmt::format(" {}B/s", NiceNum(FilteredCompressedBytesPerSecond.GetCurrent())),
-
- UploadedBlockCount.load() + UploadedChunkCount.load(),
- UploadBlockCount + UploadChunkCount,
- NiceBytes(UploadedRawSize),
- NiceBytes(TotalRawSize),
-
- NiceBytes(UploadedCompressedSize),
- (UploadedBlockCount == UploadBlockCount && UploadedChunkCount == UploadChunkCount)
- ? ""
- : fmt::format(" {}bits/s", NiceNum(FilteredUploadedBytesPerSecond.GetCurrent())));
-
- Progress.UpdateState({.Task = "Uploading blobs ",
- .Details = Details,
- .TotalCount = gsl::narrow<uint64_t>(TotalRawSize),
- .RemainingCount = gsl::narrow<uint64_t>(TotalRawSize - UploadedRawSize),
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
- false);
+ ZEN_TRACE_CPU("AsyncUploadLooseChunk_Singlepart");
+ try
+ {
+ m_Storage.BuildStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload);
+ }
+ catch (const std::exception&)
+ {
+ // Silence http errors due to abort
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
+ }
+ if (m_AbortFlag)
+ {
+ return;
+ }
+ if (m_Options.IsVerbose)
+ {
+ ZEN_INFO("Uploaded chunk {} ({})", RawHash, NiceBytes(PayloadSize));
+ }
+ Context.TempUploadStats.ChunksBytes += Payload.GetSize();
+ Context.TempUploadStats.ChunkCount++;
+ Context.UploadedCompressedChunkSize += Payload.GetSize();
+ Context.UploadedRawChunkSize += RawSize;
+ if (Context.UploadedChunkCount.fetch_add(1) + 1 == Context.UploadChunkCount &&
+ Context.UploadedBlockCount == Context.UploadBlockCount)
+ {
+ Context.FilteredUploadedBytesPerSecond.Stop();
+ }
});
-
- ZEN_ASSERT(m_AbortFlag || QueuedPendingInMemoryBlocksForUpload.load() == 0);
-
- Progress.Finish();
-
- TempUploadStats.ElapsedWallTimeUS += FilteredUploadedBytesPerSecond.GetElapsedTimeUS();
- TempLooseChunksStats.CompressChunksElapsedWallTimeUS += FilteredCompressedBytesPerSecond.GetElapsedTimeUS();
- }
}
CompositeBuffer
@@ -6432,7 +6624,7 @@ BuildsOperationUploadFolder::CompressChunk(const ChunkedFolderContent& Content,
throw std::runtime_error(fmt::format("Fetched chunk {} has invalid size", ChunkHash));
}
- const bool ShouldCompressChunk = IsChunkCompressable(m_NonCompressableExtensionHashes, Content, Lookup, ChunkIndex);
+ const bool ShouldCompressChunk = IsChunkCompressable(m_NonCompressableExtensionHashes, Lookup, ChunkIndex);
const OodleCompressionLevel CompressionLevel = ShouldCompressChunk ? OodleCompressionLevel::VeryFast : OodleCompressionLevel::None;
if (ShouldCompressChunk)
@@ -6521,7 +6713,8 @@ BuildsOperationUploadFolder::CompressChunk(const ChunkedFolderContent& Content,
return std::move(CompressedBlob).GetCompressed();
}
-BuildsOperationValidateBuildPart::BuildsOperationValidateBuildPart(OperationLogOutput& OperationLogOutput,
+BuildsOperationValidateBuildPart::BuildsOperationValidateBuildPart(LoggerRef Log,
+ ProgressBase& Progress,
BuildStorageBase& Storage,
std::atomic<bool>& AbortFlag,
std::atomic<bool>& PauseFlag,
@@ -6532,7 +6725,8 @@ BuildsOperationValidateBuildPart::BuildsOperationValidateBuildPart(OperationLogO
const std::string_view BuildPartName,
const Options& Options)
-: m_LogOutput(OperationLogOutput)
+: m_Log(Log)
+, m_Progress(Progress)
, m_Storage(Storage)
, m_AbortFlag(AbortFlag)
, m_PauseFlag(PauseFlag)
@@ -6551,89 +6745,24 @@ BuildsOperationValidateBuildPart::Execute()
ZEN_TRACE_CPU("ValidateBuildPart");
try
{
- enum class TaskSteps : uint32_t
- {
- FetchBuild,
- FetchBuildPart,
- ValidateBlobs,
- Cleanup,
- StepCount
- };
-
auto EndProgress =
- MakeGuard([&]() { m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::StepCount, (uint32_t)TaskSteps::StepCount); });
+ MakeGuard([&]() { m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::StepCount, (uint32_t)TaskSteps::StepCount); });
Stopwatch Timer;
auto _ = MakeGuard([&]() {
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Validated build part {}/{} ('{}') in {}",
- m_BuildId,
- m_BuildPartId,
- m_BuildPartName,
- NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ ZEN_INFO("Validated build part {}/{} ('{}') in {}",
+ m_BuildId,
+ m_BuildPartId,
+ m_BuildPartName,
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
});
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::FetchBuild, (uint32_t)TaskSteps::StepCount);
-
- CbObject Build = m_Storage.GetBuild(m_BuildId);
- if (!m_BuildPartName.empty())
- {
- m_BuildPartId = Build["parts"sv].AsObjectView()[m_BuildPartName].AsObjectId();
- if (m_BuildPartId == Oid::Zero)
- {
- throw std::runtime_error(fmt::format("Build {} does not have a part named '{}'", m_BuildId, m_BuildPartName));
- }
- }
- m_ValidateStats.BuildBlobSize = Build.GetSize();
- uint64_t PreferredMultipartChunkSize = 32u * 1024u * 1024u;
- if (auto ChunkSize = Build["chunkSize"sv].AsUInt64(); ChunkSize != 0)
- {
- PreferredMultipartChunkSize = ChunkSize;
- }
-
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::FetchBuildPart, (uint32_t)TaskSteps::StepCount);
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::FetchBuild, (uint32_t)TaskSteps::StepCount);
- CbObject BuildPart = m_Storage.GetBuildPart(m_BuildId, m_BuildPartId);
- m_ValidateStats.BuildPartSize = BuildPart.GetSize();
- if (!m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Validating build part {}/{} ({})",
- m_BuildId,
- m_BuildPartId,
- NiceBytes(BuildPart.GetSize()));
- }
- std::vector<IoHash> ChunkAttachments;
- if (const CbObjectView ChunkAttachmentsView = BuildPart["chunkAttachments"sv].AsObjectView())
- {
- for (CbFieldView LooseFileView : ChunkAttachmentsView["rawHashes"sv])
- {
- ChunkAttachments.push_back(LooseFileView.AsBinaryAttachment());
- }
- }
- m_ValidateStats.ChunkAttachmentCount = ChunkAttachments.size();
- std::vector<IoHash> BlockAttachments;
- if (const CbObjectView BlockAttachmentsView = BuildPart["blockAttachments"sv].AsObjectView())
- {
- {
- for (CbFieldView BlocksView : BlockAttachmentsView["rawHashes"sv])
- {
- BlockAttachments.push_back(BlocksView.AsBinaryAttachment());
- }
- }
- }
- m_ValidateStats.BlockAttachmentCount = BlockAttachments.size();
-
- std::vector<ChunkBlockDescription> VerifyBlockDescriptions =
- ParseChunkBlockDescriptionList(m_Storage.GetBlockMetadatas(m_BuildId, BlockAttachments));
- if (VerifyBlockDescriptions.size() != BlockAttachments.size())
- {
- throw std::runtime_error(fmt::format("Uploaded blocks metadata could not all be found, {} blocks metadata is missing",
- BlockAttachments.size() - VerifyBlockDescriptions.size()));
- }
+ ResolvedBuildPart Resolved = ResolveBuildPart();
ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
@@ -6643,150 +6772,23 @@ BuildsOperationValidateBuildPart::Execute()
CreateDirectories(TempFolder);
auto __ = MakeGuard([this, TempFolder]() { CleanAndRemoveDirectory(m_IOWorkerPool, m_AbortFlag, m_PauseFlag, TempFolder); });
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::ValidateBlobs, (uint32_t)TaskSteps::StepCount);
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::ValidateBlobs, (uint32_t)TaskSteps::StepCount);
- std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Validate Blobs"));
- OperationLogOutput::ProgressBar& Progress(*ProgressBarPtr);
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Validate Blobs");
- uint64_t AttachmentsToVerifyCount = ChunkAttachments.size() + BlockAttachments.size();
- FilteredRate FilteredDownloadedBytesPerSecond;
- FilteredRate FilteredVerifiedBytesPerSecond;
+ const uint64_t AttachmentsToVerifyCount = Resolved.ChunkAttachments.size() + Resolved.BlockAttachments.size();
+ FilteredRate FilteredDownloadedBytesPerSecond;
+ FilteredRate FilteredVerifiedBytesPerSecond;
- std::atomic<uint64_t> MultipartAttachmentCount = 0;
+ ValidateBlobsContext Context{.Work = Work,
+ .AttachmentsToVerifyCount = AttachmentsToVerifyCount,
+ .FilteredDownloadedBytesPerSecond = FilteredDownloadedBytesPerSecond,
+ .FilteredVerifiedBytesPerSecond = FilteredVerifiedBytesPerSecond};
- for (const IoHash& ChunkAttachment : ChunkAttachments)
- {
- Work.ScheduleWork(
- m_NetworkPool,
- [this,
- &Work,
- AttachmentsToVerifyCount,
- &TempFolder,
- PreferredMultipartChunkSize,
- &FilteredDownloadedBytesPerSecond,
- &FilteredVerifiedBytesPerSecond,
- &ChunkAttachments,
- ChunkAttachment = IoHash(ChunkAttachment)](std::atomic<bool>&) {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("ValidateBuildPart_GetChunk");
-
- FilteredDownloadedBytesPerSecond.Start();
- DownloadLargeBlob(
- m_Storage,
- TempFolder,
- m_BuildId,
- ChunkAttachment,
- PreferredMultipartChunkSize,
- Work,
- m_NetworkPool,
- m_DownloadStats.DownloadedChunkByteCount,
- m_DownloadStats.MultipartAttachmentCount,
- [this,
- &Work,
- AttachmentsToVerifyCount,
- &FilteredDownloadedBytesPerSecond,
- &FilteredVerifiedBytesPerSecond,
- ChunkHash = IoHash(ChunkAttachment)](IoBuffer&& Payload) {
- m_DownloadStats.DownloadedChunkCount++;
- Payload.SetContentType(ZenContentType::kCompressedBinary);
- if (!m_AbortFlag)
- {
- Work.ScheduleWork(
- m_IOWorkerPool,
- [this,
- AttachmentsToVerifyCount,
- &FilteredDownloadedBytesPerSecond,
- &FilteredVerifiedBytesPerSecond,
- Payload = IoBuffer(std::move(Payload)),
- ChunkHash](std::atomic<bool>&) mutable {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("ValidateBuildPart_Validate");
-
- if (m_DownloadStats.DownloadedChunkCount + m_DownloadStats.DownloadedBlockCount ==
- AttachmentsToVerifyCount)
- {
- FilteredDownloadedBytesPerSecond.Stop();
- }
-
- FilteredVerifiedBytesPerSecond.Start();
-
- uint64_t CompressedSize;
- uint64_t DecompressedSize;
- ValidateBlob(m_AbortFlag, std::move(Payload), ChunkHash, CompressedSize, DecompressedSize);
- m_ValidateStats.VerifiedAttachmentCount++;
- m_ValidateStats.VerifiedByteCount += DecompressedSize;
- if (m_ValidateStats.VerifiedAttachmentCount.load() == AttachmentsToVerifyCount)
- {
- FilteredVerifiedBytesPerSecond.Stop();
- }
- }
- });
- }
- });
- }
- });
- }
-
- for (const IoHash& BlockAttachment : BlockAttachments)
- {
- Work.ScheduleWork(
- m_NetworkPool,
- [this,
- &Work,
- AttachmentsToVerifyCount,
- &FilteredDownloadedBytesPerSecond,
- &FilteredVerifiedBytesPerSecond,
- BlockAttachment = IoHash(BlockAttachment)](std::atomic<bool>&) {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("ValidateBuildPart_GetBlock");
-
- FilteredDownloadedBytesPerSecond.Start();
- IoBuffer Payload = m_Storage.GetBuildBlob(m_BuildId, BlockAttachment);
- m_DownloadStats.DownloadedBlockCount++;
- m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize();
- if (m_DownloadStats.DownloadedChunkCount + m_DownloadStats.DownloadedBlockCount == AttachmentsToVerifyCount)
- {
- FilteredDownloadedBytesPerSecond.Stop();
- }
- if (!Payload)
- {
- throw std::runtime_error(fmt::format("Block attachment {} could not be found", BlockAttachment));
- }
- if (!m_AbortFlag)
- {
- Work.ScheduleWork(
- m_IOWorkerPool,
- [this,
- &FilteredVerifiedBytesPerSecond,
- AttachmentsToVerifyCount,
- Payload = std::move(Payload),
- BlockAttachment](std::atomic<bool>&) mutable {
- if (!m_AbortFlag)
- {
- ZEN_TRACE_CPU("ValidateBuildPart_ValidateBlock");
-
- FilteredVerifiedBytesPerSecond.Start();
-
- uint64_t CompressedSize;
- uint64_t DecompressedSize;
- ValidateChunkBlock(std::move(Payload), BlockAttachment, CompressedSize, DecompressedSize);
- m_ValidateStats.VerifiedAttachmentCount++;
- m_ValidateStats.VerifiedByteCount += DecompressedSize;
- if (m_ValidateStats.VerifiedAttachmentCount.load() == AttachmentsToVerifyCount)
- {
- FilteredVerifiedBytesPerSecond.Stop();
- }
- }
- });
- }
- }
- });
- }
+ ScheduleChunkAttachmentValidation(Context, Resolved.ChunkAttachments, TempFolder, Resolved.PreferredMultipartChunkSize);
+ ScheduleBlockAttachmentValidation(Context, Resolved.BlockAttachments);
- Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
+ Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
ZEN_UNUSED(PendingWork);
const uint64_t DownloadedAttachmentCount = m_DownloadStats.DownloadedChunkCount + m_DownloadStats.DownloadedBlockCount;
@@ -6805,20 +6807,20 @@ BuildsOperationValidateBuildPart::Execute()
NiceBytes(m_ValidateStats.VerifiedByteCount.load()),
NiceNum(FilteredVerifiedBytesPerSecond.GetCurrent()));
- Progress.UpdateState(
+ ProgressBar->UpdateState(
{.Task = "Validating blobs ",
.Details = Details,
.TotalCount = gsl::narrow<uint64_t>(AttachmentsToVerifyCount * 2),
.RemainingCount = gsl::narrow<uint64_t>(AttachmentsToVerifyCount * 2 -
(DownloadedAttachmentCount + m_ValidateStats.VerifiedAttachmentCount.load())),
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
false);
});
- Progress.Finish();
+ ProgressBar->Finish();
m_ValidateStats.ElapsedWallTimeUS = Timer.GetElapsedTimeUs();
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::Cleanup, (uint32_t)TaskSteps::StepCount);
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::Cleanup, (uint32_t)TaskSteps::StepCount);
}
catch (const std::exception&)
{
@@ -6827,7 +6829,189 @@ BuildsOperationValidateBuildPart::Execute()
}
}
-BuildsOperationPrimeCache::BuildsOperationPrimeCache(OperationLogOutput& OperationLogOutput,
+BuildsOperationValidateBuildPart::ResolvedBuildPart
+BuildsOperationValidateBuildPart::ResolveBuildPart()
+{
+ ResolvedBuildPart Result;
+ Result.PreferredMultipartChunkSize = 32u * 1024u * 1024u;
+
+ CbObject Build = m_Storage.GetBuild(m_BuildId);
+ if (!m_BuildPartName.empty())
+ {
+ m_BuildPartId = Build["parts"sv].AsObjectView()[m_BuildPartName].AsObjectId();
+ if (m_BuildPartId == Oid::Zero)
+ {
+ throw std::runtime_error(fmt::format("Build {} does not have a part named '{}'", m_BuildId, m_BuildPartName));
+ }
+ }
+ m_ValidateStats.BuildBlobSize = Build.GetSize();
+ if (auto ChunkSize = Build["chunkSize"sv].AsUInt64(); ChunkSize != 0)
+ {
+ Result.PreferredMultipartChunkSize = ChunkSize;
+ }
+
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::FetchBuildPart, (uint32_t)TaskSteps::StepCount);
+
+ CbObject BuildPart = m_Storage.GetBuildPart(m_BuildId, m_BuildPartId);
+ m_ValidateStats.BuildPartSize = BuildPart.GetSize();
+ if (!m_Options.IsQuiet)
+ {
+ ZEN_INFO("Validating build part {}/{} ({})", m_BuildId, m_BuildPartId, NiceBytes(BuildPart.GetSize()));
+ }
+ if (const CbObjectView ChunkAttachmentsView = BuildPart["chunkAttachments"sv].AsObjectView())
+ {
+ for (CbFieldView LooseFileView : ChunkAttachmentsView["rawHashes"sv])
+ {
+ Result.ChunkAttachments.push_back(LooseFileView.AsBinaryAttachment());
+ }
+ }
+ m_ValidateStats.ChunkAttachmentCount = Result.ChunkAttachments.size();
+ if (const CbObjectView BlockAttachmentsView = BuildPart["blockAttachments"sv].AsObjectView())
+ {
+ for (CbFieldView BlocksView : BlockAttachmentsView["rawHashes"sv])
+ {
+ Result.BlockAttachments.push_back(BlocksView.AsBinaryAttachment());
+ }
+ }
+ m_ValidateStats.BlockAttachmentCount = Result.BlockAttachments.size();
+
+ std::vector<ChunkBlockDescription> VerifyBlockDescriptions =
+ ParseChunkBlockDescriptionList(m_Storage.GetBlockMetadatas(m_BuildId, Result.BlockAttachments));
+ if (VerifyBlockDescriptions.size() != Result.BlockAttachments.size())
+ {
+ throw std::runtime_error(fmt::format("Uploaded blocks metadata could not all be found, {} blocks metadata is missing",
+ Result.BlockAttachments.size() - VerifyBlockDescriptions.size()));
+ }
+
+ return Result;
+}
+
+void
+BuildsOperationValidateBuildPart::ScheduleChunkAttachmentValidation(ValidateBlobsContext& Context,
+ std::span<const IoHash> ChunkAttachments,
+ const std::filesystem::path& TempFolder,
+ uint64_t PreferredMultipartChunkSize)
+{
+ for (const IoHash& ChunkAttachment : ChunkAttachments)
+ {
+ Context.Work.ScheduleWork(
+ m_NetworkPool,
+ [this, &Context, &TempFolder, PreferredMultipartChunkSize, ChunkAttachment = IoHash(ChunkAttachment)](std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ ZEN_TRACE_CPU("ValidateBuildPart_GetChunk");
+
+ Context.FilteredDownloadedBytesPerSecond.Start();
+ DownloadLargeBlob(
+ m_Storage,
+ TempFolder,
+ m_BuildId,
+ ChunkAttachment,
+ PreferredMultipartChunkSize,
+ Context.Work,
+ m_NetworkPool,
+ m_DownloadStats.DownloadedChunkByteCount,
+ m_DownloadStats.MultipartAttachmentCount,
+ [this, &Context, ChunkHash = IoHash(ChunkAttachment)](IoBuffer&& Payload) {
+ m_DownloadStats.DownloadedChunkCount++;
+ Payload.SetContentType(ZenContentType::kCompressedBinary);
+ if (!m_AbortFlag)
+ {
+ Context.Work.ScheduleWork(
+ m_IOWorkerPool,
+ [this, &Context, Payload = IoBuffer(std::move(Payload)), ChunkHash](std::atomic<bool>&) mutable {
+ if (!m_AbortFlag)
+ {
+ ValidateDownloadedChunk(Context, ChunkHash, std::move(Payload));
+ }
+ });
+ }
+ });
+ }
+ });
+ }
+}
+
+void
+BuildsOperationValidateBuildPart::ScheduleBlockAttachmentValidation(ValidateBlobsContext& Context, std::span<const IoHash> BlockAttachments)
+{
+ for (const IoHash& BlockAttachment : BlockAttachments)
+ {
+ Context.Work.ScheduleWork(m_NetworkPool, [this, &Context, BlockAttachment = IoHash(BlockAttachment)](std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ ZEN_TRACE_CPU("ValidateBuildPart_GetBlock");
+
+ Context.FilteredDownloadedBytesPerSecond.Start();
+ IoBuffer Payload = m_Storage.GetBuildBlob(m_BuildId, BlockAttachment);
+ m_DownloadStats.DownloadedBlockCount++;
+ m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize();
+ if (m_DownloadStats.DownloadedChunkCount + m_DownloadStats.DownloadedBlockCount == Context.AttachmentsToVerifyCount)
+ {
+ Context.FilteredDownloadedBytesPerSecond.Stop();
+ }
+ if (!Payload)
+ {
+ throw std::runtime_error(fmt::format("Block attachment {} could not be found", BlockAttachment));
+ }
+ if (!m_AbortFlag)
+ {
+ Context.Work.ScheduleWork(m_IOWorkerPool,
+ [this, &Context, Payload = std::move(Payload), BlockAttachment](std::atomic<bool>&) mutable {
+ if (!m_AbortFlag)
+ {
+ ValidateDownloadedBlock(Context, BlockAttachment, std::move(Payload));
+ }
+ });
+ }
+ }
+ });
+ }
+}
+
+void
+BuildsOperationValidateBuildPart::ValidateDownloadedChunk(ValidateBlobsContext& Context, const IoHash& ChunkHash, IoBuffer Payload)
+{
+ ZEN_TRACE_CPU("ValidateBuildPart_Validate");
+
+ if (m_DownloadStats.DownloadedChunkCount + m_DownloadStats.DownloadedBlockCount == Context.AttachmentsToVerifyCount)
+ {
+ Context.FilteredDownloadedBytesPerSecond.Stop();
+ }
+
+ Context.FilteredVerifiedBytesPerSecond.Start();
+
+ uint64_t CompressedSize;
+ uint64_t DecompressedSize;
+ ValidateBlob(m_AbortFlag, std::move(Payload), ChunkHash, CompressedSize, DecompressedSize);
+ m_ValidateStats.VerifiedAttachmentCount++;
+ m_ValidateStats.VerifiedByteCount += DecompressedSize;
+ if (m_ValidateStats.VerifiedAttachmentCount.load() == Context.AttachmentsToVerifyCount)
+ {
+ Context.FilteredVerifiedBytesPerSecond.Stop();
+ }
+}
+
+void
+BuildsOperationValidateBuildPart::ValidateDownloadedBlock(ValidateBlobsContext& Context, const IoHash& BlockAttachment, IoBuffer Payload)
+{
+ ZEN_TRACE_CPU("ValidateBuildPart_ValidateBlock");
+
+ Context.FilteredVerifiedBytesPerSecond.Start();
+
+ uint64_t CompressedSize;
+ uint64_t DecompressedSize;
+ ValidateChunkBlock(std::move(Payload), BlockAttachment, CompressedSize, DecompressedSize);
+ m_ValidateStats.VerifiedAttachmentCount++;
+ m_ValidateStats.VerifiedByteCount += DecompressedSize;
+ if (m_ValidateStats.VerifiedAttachmentCount.load() == Context.AttachmentsToVerifyCount)
+ {
+ Context.FilteredVerifiedBytesPerSecond.Stop();
+ }
+}
+
+BuildsOperationPrimeCache::BuildsOperationPrimeCache(LoggerRef Log,
+ ProgressBase& Progress,
StorageInstance& Storage,
std::atomic<bool>& AbortFlag,
std::atomic<bool>& PauseFlag,
@@ -6836,7 +7020,8 @@ BuildsOperationPrimeCache::BuildsOperationPrimeCache(OperationLogOutput& Opera
std::span<const Oid> BuildPartIds,
const Options& Options,
BuildStorageCache::Statistics& StorageCacheStats)
-: m_LogOutput(OperationLogOutput)
+: m_Log(Log)
+, m_Progress(Progress)
, m_Storage(Storage)
, m_AbortFlag(AbortFlag)
, m_PauseFlag(PauseFlag)
@@ -6858,9 +7043,69 @@ BuildsOperationPrimeCache::Execute()
Stopwatch PrimeTimer;
tsl::robin_map<IoHash, uint64_t, IoHash::Hasher> LooseChunkRawSizes;
+ tsl::robin_set<IoHash, IoHash::Hasher> BuildBlobs;
+ CollectReferencedBlobs(BuildBlobs, LooseChunkRawSizes);
+
+ if (!m_Options.IsQuiet)
+ {
+ ZEN_INFO("Found {} referenced blobs", BuildBlobs.size());
+ }
- tsl::robin_set<IoHash, IoHash::Hasher> BuildBlobs;
+ if (BuildBlobs.empty())
+ {
+ return;
+ }
+ std::vector<IoHash> BlobsToDownload = FilterAlreadyCachedBlobs(BuildBlobs);
+
+ if (BlobsToDownload.empty())
+ {
+ return;
+ }
+
+ std::atomic<uint64_t> MultipartAttachmentCount;
+ std::atomic<size_t> CompletedDownloadCount;
+ FilteredRate FilteredDownloadedBytesPerSecond;
+
+ ScheduleBlobDownloads(BlobsToDownload,
+ LooseChunkRawSizes,
+ MultipartAttachmentCount,
+ CompletedDownloadCount,
+ FilteredDownloadedBytesPerSecond);
+
+ if (m_AbortFlag)
+ {
+ return;
+ }
+
+ if (m_Storage.CacheStorage)
+ {
+ m_Storage.CacheStorage->Flush(m_Progress.GetProgressUpdateDelayMS(), [this](intptr_t Remaining) -> bool {
+ ZEN_UNUSED(Remaining);
+ if (!m_Options.IsQuiet)
+ {
+ ZEN_INFO("Waiting for {} blobs to finish upload to '{}'", Remaining, m_Storage.CacheHost.Name);
+ }
+ return !m_AbortFlag;
+ });
+ }
+
+ if (!m_Options.IsQuiet)
+ {
+ uint64_t DownloadedBytes = m_DownloadStats.DownloadedChunkByteCount.load() + m_DownloadStats.DownloadedBlockByteCount.load();
+ ZEN_INFO("Downloaded {} ({}bits/s) in {}. {} as multipart. Completed in {}",
+ NiceBytes(DownloadedBytes),
+ NiceNum(GetBytesPerSecond(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS(), DownloadedBytes * 8)),
+ NiceTimeSpanMs(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS() / 1000),
+ MultipartAttachmentCount.load(),
+ NiceTimeSpanMs(PrimeTimer.GetElapsedTimeMs()));
+ }
+}
+
+void
+BuildsOperationPrimeCache::CollectReferencedBlobs(tsl::robin_set<IoHash, IoHash::Hasher>& OutBuildBlobs,
+ tsl::robin_map<IoHash, uint64_t, IoHash::Hasher>& OutLooseChunkRawSizes)
+{
for (const Oid& BuildPartId : m_BuildPartIds)
{
CbObject BuildPart = m_Storage.BuildStorage->GetBuildPart(m_BuildId, BuildPartId);
@@ -6878,26 +7123,20 @@ BuildsOperationPrimeCache::Execute()
ChunkRawSizes.size()));
}
- BuildBlobs.reserve(ChunkAttachments.size() + BlockAttachments.size());
- BuildBlobs.insert(BlockAttachments.begin(), BlockAttachments.end());
- BuildBlobs.insert(ChunkAttachments.begin(), ChunkAttachments.end());
+ OutBuildBlobs.reserve(ChunkAttachments.size() + BlockAttachments.size());
+ OutBuildBlobs.insert(BlockAttachments.begin(), BlockAttachments.end());
+ OutBuildBlobs.insert(ChunkAttachments.begin(), ChunkAttachments.end());
for (size_t ChunkAttachmentIndex = 0; ChunkAttachmentIndex < ChunkAttachments.size(); ChunkAttachmentIndex++)
{
- LooseChunkRawSizes.insert_or_assign(ChunkAttachments[ChunkAttachmentIndex], ChunkRawSizes[ChunkAttachmentIndex]);
+ OutLooseChunkRawSizes.insert_or_assign(ChunkAttachments[ChunkAttachmentIndex], ChunkRawSizes[ChunkAttachmentIndex]);
}
}
+}
- if (!m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Found {} referenced blobs", BuildBlobs.size());
- }
-
- if (BuildBlobs.empty())
- {
- return;
- }
-
+std::vector<IoHash>
+BuildsOperationPrimeCache::FilterAlreadyCachedBlobs(const tsl::robin_set<IoHash, IoHash::Hasher>& BuildBlobs)
+{
std::vector<IoHash> BlobsToDownload;
BlobsToDownload.reserve(BuildBlobs.size());
@@ -6923,11 +7162,10 @@ BuildsOperationPrimeCache::Execute()
if (FoundCount > 0 && !m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Remote cache : Found {} out of {} needed blobs in {}",
- FoundCount,
- BuildBlobs.size(),
- NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ ZEN_INFO("Remote cache : Found {} out of {} needed blobs in {}",
+ FoundCount,
+ BuildBlobs.size(),
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
}
}
@@ -6935,169 +7173,170 @@ BuildsOperationPrimeCache::Execute()
{
BlobsToDownload.insert(BlobsToDownload.end(), BuildBlobs.begin(), BuildBlobs.end());
}
+ return BlobsToDownload;
+}
- if (BlobsToDownload.empty())
- {
- return;
- }
-
- std::atomic<uint64_t> MultipartAttachmentCount;
- std::atomic<size_t> CompletedDownloadCount;
- FilteredRate FilteredDownloadedBytesPerSecond;
-
- {
- std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Downloading"));
- OperationLogOutput::ProgressBar& Progress(*ProgressBarPtr);
-
- ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
-
- const size_t BlobCount = BlobsToDownload.size();
-
- for (size_t BlobIndex = 0; BlobIndex < BlobCount; BlobIndex++)
- {
- Work.ScheduleWork(m_NetworkPool,
- [this,
- &Work,
- &BlobsToDownload,
- BlobCount,
- &LooseChunkRawSizes,
- &CompletedDownloadCount,
- &FilteredDownloadedBytesPerSecond,
- &MultipartAttachmentCount,
- BlobIndex](std::atomic<bool>&) {
- if (!m_AbortFlag)
- {
- const IoHash& BlobHash = BlobsToDownload[BlobIndex];
-
- bool IsLargeBlob = false;
+void
+BuildsOperationPrimeCache::ScheduleBlobDownloads(std::span<const IoHash> BlobsToDownload,
+ const tsl::robin_map<IoHash, uint64_t, IoHash::Hasher>& LooseChunkRawSizes,
+ std::atomic<uint64_t>& MultipartAttachmentCount,
+ std::atomic<size_t>& CompletedDownloadCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond)
+{
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Downloading");
- if (auto It = LooseChunkRawSizes.find(BlobHash); It != LooseChunkRawSizes.end())
- {
- IsLargeBlob = It->second >= m_Options.LargeAttachmentSize;
- }
+ ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
- FilteredDownloadedBytesPerSecond.Start();
+ const size_t BlobCount = BlobsToDownload.size();
- if (IsLargeBlob)
- {
- DownloadLargeBlob(
- *m_Storage.BuildStorage,
- m_TempPath,
- m_BuildId,
- BlobHash,
- m_Options.PreferredMultipartChunkSize,
- Work,
- m_NetworkPool,
- m_DownloadStats.DownloadedChunkByteCount,
- MultipartAttachmentCount,
- [this, BlobCount, BlobHash, &FilteredDownloadedBytesPerSecond, &CompletedDownloadCount](
- IoBuffer&& Payload) {
- m_DownloadStats.DownloadedChunkCount++;
- m_DownloadStats.RequestsCompleteCount++;
+ for (size_t BlobIndex = 0; BlobIndex < BlobCount; BlobIndex++)
+ {
+ Work.ScheduleWork(
+ m_NetworkPool,
+ [this,
+ &Work,
+ BlobsToDownload,
+ BlobCount,
+ &LooseChunkRawSizes,
+ &CompletedDownloadCount,
+ &FilteredDownloadedBytesPerSecond,
+ &MultipartAttachmentCount,
+ BlobIndex](std::atomic<bool>&) {
+ if (!m_AbortFlag)
+ {
+ const IoHash& BlobHash = BlobsToDownload[BlobIndex];
+ bool IsLargeBlob = false;
+ if (auto It = LooseChunkRawSizes.find(BlobHash); It != LooseChunkRawSizes.end())
+ {
+ IsLargeBlob = It->second >= m_Options.LargeAttachmentSize;
+ }
- if (!m_AbortFlag)
- {
- if (Payload && m_Storage.CacheStorage)
- {
- m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
- BlobHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(SharedBuffer(Payload)));
- }
- }
- CompletedDownloadCount++;
- if (CompletedDownloadCount == BlobCount)
- {
- FilteredDownloadedBytesPerSecond.Stop();
- }
- });
- }
- else
- {
- IoBuffer Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash);
- m_DownloadStats.DownloadedBlockCount++;
- m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize();
- m_DownloadStats.RequestsCompleteCount++;
+ FilteredDownloadedBytesPerSecond.Start();
- if (!m_AbortFlag)
- {
- if (Payload && m_Storage.CacheStorage)
- {
- m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
- BlobHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(SharedBuffer(std::move(Payload))));
- }
- }
- CompletedDownloadCount++;
- if (CompletedDownloadCount == BlobCount)
- {
- FilteredDownloadedBytesPerSecond.Stop();
- }
- }
- }
- });
- }
+ if (IsLargeBlob)
+ {
+ DownloadLargeBlobForCache(Work,
+ BlobHash,
+ BlobCount,
+ CompletedDownloadCount,
+ MultipartAttachmentCount,
+ FilteredDownloadedBytesPerSecond);
+ }
+ else
+ {
+ DownloadSingleBlobForCache(BlobHash, BlobCount, CompletedDownloadCount, FilteredDownloadedBytesPerSecond);
+ }
+ }
+ });
+ }
- Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
- ZEN_UNUSED(PendingWork);
+ Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
+ ZEN_UNUSED(PendingWork);
- uint64_t DownloadedBytes = m_DownloadStats.DownloadedChunkByteCount.load() + m_DownloadStats.DownloadedBlockByteCount.load();
- FilteredDownloadedBytesPerSecond.Update(DownloadedBytes);
-
- std::string DownloadRateString = (CompletedDownloadCount == BlobCount)
- ? ""
- : fmt::format(" {}bits/s", NiceNum(FilteredDownloadedBytesPerSecond.GetCurrent() * 8));
- std::string UploadDetails = m_Storage.CacheStorage ? fmt::format(" {} ({}) uploaded.",
- m_StorageCacheStats.PutBlobCount.load(),
- NiceBytes(m_StorageCacheStats.PutBlobByteCount.load()))
- : "";
-
- std::string Details = fmt::format("{}/{} ({}{}) downloaded.{}",
- CompletedDownloadCount.load(),
- BlobCount,
- NiceBytes(DownloadedBytes),
- DownloadRateString,
- UploadDetails);
- Progress.UpdateState({.Task = "Downloading",
+ uint64_t DownloadedBytes = m_DownloadStats.DownloadedChunkByteCount.load() + m_DownloadStats.DownloadedBlockByteCount.load();
+ FilteredDownloadedBytesPerSecond.Update(DownloadedBytes);
+
+ std::string DownloadRateString = (CompletedDownloadCount == BlobCount)
+ ? ""
+ : fmt::format(" {}bits/s", NiceNum(FilteredDownloadedBytesPerSecond.GetCurrent() * 8));
+ std::string UploadDetails = m_Storage.CacheStorage ? fmt::format(" {} ({}) uploaded.",
+ m_StorageCacheStats.PutBlobCount.load(),
+ NiceBytes(m_StorageCacheStats.PutBlobByteCount.load()))
+ : "";
+
+ std::string Details = fmt::format("{}/{} ({}{}) downloaded.{}",
+ CompletedDownloadCount.load(),
+ BlobCount,
+ NiceBytes(DownloadedBytes),
+ DownloadRateString,
+ UploadDetails);
+ ProgressBar->UpdateState({.Task = "Downloading",
.Details = Details,
.TotalCount = BlobCount,
.RemainingCount = BlobCount - CompletedDownloadCount.load(),
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
false);
- });
+ });
+
+ FilteredDownloadedBytesPerSecond.Stop();
+ ProgressBar->Finish();
+}
- FilteredDownloadedBytesPerSecond.Stop();
+void
+BuildsOperationPrimeCache::DownloadLargeBlobForCache(ParallelWork& Work,
+ const IoHash& BlobHash,
+ size_t BlobCount,
+ std::atomic<size_t>& CompletedDownloadCount,
+ std::atomic<uint64_t>& MultipartAttachmentCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond)
+{
+ DownloadLargeBlob(*m_Storage.BuildStorage,
+ m_TempPath,
+ m_BuildId,
+ BlobHash,
+ m_Options.PreferredMultipartChunkSize,
+ Work,
+ m_NetworkPool,
+ m_DownloadStats.DownloadedChunkByteCount,
+ MultipartAttachmentCount,
+ [this, BlobCount, BlobHash, &FilteredDownloadedBytesPerSecond, &CompletedDownloadCount](IoBuffer&& Payload) {
+ m_DownloadStats.DownloadedChunkCount++;
+ m_DownloadStats.RequestsCompleteCount++;
+
+ if (!m_AbortFlag)
+ {
+ if (Payload && m_Storage.CacheStorage)
+ {
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ BlobHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(SharedBuffer(Payload)));
+ }
+ }
+ if (CompletedDownloadCount.fetch_add(1) + 1 == BlobCount)
+ {
+ FilteredDownloadedBytesPerSecond.Stop();
+ }
+ });
+}
- Progress.Finish();
- }
- if (m_AbortFlag)
+void
+BuildsOperationPrimeCache::DownloadSingleBlobForCache(const IoHash& BlobHash,
+ size_t BlobCount,
+ std::atomic<size_t>& CompletedDownloadCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond)
+{
+ IoBuffer Payload;
+ try
{
- return;
- }
+ Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash);
- if (m_Storage.CacheStorage)
+ m_DownloadStats.DownloadedBlockCount++;
+ m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize();
+ m_DownloadStats.RequestsCompleteCount++;
+ }
+ catch (const std::exception&)
{
- m_Storage.CacheStorage->Flush(m_LogOutput.GetProgressUpdateDelayMS(), [this](intptr_t Remaining) -> bool {
- ZEN_UNUSED(Remaining);
- if (!m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Waiting for {} blobs to finish upload to '{}'", Remaining, m_Storage.CacheHost.Name);
- }
- return !m_AbortFlag;
- });
+ // Silence http errors due to abort
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
}
- if (!m_Options.IsQuiet)
+ if (!m_AbortFlag)
{
- uint64_t DownloadedBytes = m_DownloadStats.DownloadedChunkByteCount.load() + m_DownloadStats.DownloadedBlockByteCount.load();
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Downloaded {} ({}bits/s) in {}. {} as multipart. Completed in {}",
- NiceBytes(DownloadedBytes),
- NiceNum(GetBytesPerSecond(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS(), DownloadedBytes * 8)),
- NiceTimeSpanMs(FilteredDownloadedBytesPerSecond.GetElapsedTimeUS() / 1000),
- MultipartAttachmentCount.load(),
- NiceTimeSpanMs(PrimeTimer.GetElapsedTimeMs()));
+ if (Payload && m_Storage.CacheStorage)
+ {
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ BlobHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(SharedBuffer(std::move(Payload))));
+ }
+ if (CompletedDownloadCount.fetch_add(1) + 1 == BlobCount)
+ {
+ FilteredDownloadedBytesPerSecond.Stop();
+ }
}
}
@@ -7212,7 +7451,7 @@ ResolveBuildPartNames(CbObjectView BuildObject,
}
ChunkedFolderContent
-GetRemoteContent(OperationLogOutput& Output,
+GetRemoteContent(LoggerRef InLog,
StorageInstance& Storage,
const Oid& BuildId,
const std::vector<std::pair<Oid, std::string>>& BuildParts,
@@ -7228,6 +7467,7 @@ GetRemoteContent(OperationLogOutput& Output,
bool DoExtraContentVerify)
{
ZEN_TRACE_CPU("GetRemoteContent");
+ ZEN_SCOPED_LOG(InLog);
Stopwatch GetBuildPartTimer;
const Oid BuildPartId = BuildParts[0].first;
@@ -7235,13 +7475,12 @@ GetRemoteContent(OperationLogOutput& Output,
CbObject BuildPartManifest = Storage.BuildStorage->GetBuildPart(BuildId, BuildPartId);
if (!IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(Output,
- "GetBuildPart {} ('{}') took {}. Payload size: {}",
- BuildPartId,
- BuildPartName,
- NiceTimeSpanMs(GetBuildPartTimer.GetElapsedTimeMs()),
- NiceBytes(BuildPartManifest.GetSize()));
- ZEN_OPERATION_LOG_INFO(Output, "{}", GetCbObjectAsNiceString(BuildPartManifest, " "sv, "\n"sv));
+ ZEN_INFO("GetBuildPart {} ('{}') took {}. Payload size: {}",
+ BuildPartId,
+ BuildPartName,
+ NiceTimeSpanMs(GetBuildPartTimer.GetElapsedTimeMs()),
+ NiceBytes(BuildPartManifest.GetSize()));
+ ZEN_INFO("{}", GetCbObjectAsNiceString(BuildPartManifest, " "sv, "\n"sv));
}
{
@@ -7251,17 +7490,16 @@ GetRemoteContent(OperationLogOutput& Output,
OutChunkController = CreateChunkingController(ChunkerName, Parameters);
}
- auto ParseBuildPartManifest = [&Output, IsQuiet, IsVerbose, DoExtraContentVerify](
- StorageInstance& Storage,
- const Oid& BuildId,
- const Oid& BuildPartId,
- CbObject BuildPartManifest,
- std::span<const std::string> IncludeWildcards,
- std::span<const std::string> ExcludeWildcards,
- const BuildManifest::Part* OptionalManifest,
- ChunkedFolderContent& OutRemoteContent,
- std::vector<ChunkBlockDescription>& OutBlockDescriptions,
- std::vector<IoHash>& OutLooseChunkHashes) {
+ auto ParseBuildPartManifest = [&Log, IsQuiet, IsVerbose, DoExtraContentVerify](StorageInstance& Storage,
+ const Oid& BuildId,
+ const Oid& BuildPartId,
+ CbObject BuildPartManifest,
+ std::span<const std::string> IncludeWildcards,
+ std::span<const std::string> ExcludeWildcards,
+ const BuildManifest::Part* OptionalManifest,
+ ChunkedFolderContent& OutRemoteContent,
+ std::vector<ChunkBlockDescription>& OutBlockDescriptions,
+ std::vector<IoHash>& OutLooseChunkHashes) {
std::vector<uint32_t> AbsoluteChunkOrders;
std::vector<uint64_t> LooseChunkRawSizes;
std::vector<IoHash> BlockRawHashes;
@@ -7284,13 +7522,13 @@ GetRemoteContent(OperationLogOutput& Output,
{
if (!IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(Output, "Fetching metadata for {} blocks", BlockRawHashes.size());
+ ZEN_INFO("Fetching metadata for {} blocks", BlockRawHashes.size());
}
Stopwatch GetBlockMetadataTimer;
bool AttemptFallback = false;
- OutBlockDescriptions = GetBlockDescriptions(Output,
+ OutBlockDescriptions = GetBlockDescriptions(Log(),
*Storage.BuildStorage,
Storage.CacheStorage.get(),
BuildId,
@@ -7301,11 +7539,10 @@ GetRemoteContent(OperationLogOutput& Output,
if (!IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(Output,
- "GetBlockMetadata for {} took {}. Found {} blocks",
- BuildPartId,
- NiceTimeSpanMs(GetBlockMetadataTimer.GetElapsedTimeMs()),
- OutBlockDescriptions.size());
+ ZEN_INFO("GetBlockMetadata for {} took {}. Found {} blocks",
+ BuildPartId,
+ NiceTimeSpanMs(GetBlockMetadataTimer.GetElapsedTimeMs()),
+ OutBlockDescriptions.size());
}
}
@@ -7414,12 +7651,11 @@ GetRemoteContent(OperationLogOutput& Output,
CbObject OverlayBuildPartManifest = Storage.BuildStorage->GetBuildPart(BuildId, OverlayBuildPartId);
if (!IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(Output,
- "GetBuildPart {} ('{}') took {}. Payload size: {}",
- OverlayBuildPartId,
- OverlayBuildPartName,
- NiceTimeSpanMs(GetOverlayBuildPartTimer.GetElapsedTimeMs()),
- NiceBytes(OverlayBuildPartManifest.GetSize()));
+ ZEN_INFO("GetBuildPart {} ('{}') took {}. Payload size: {}",
+ OverlayBuildPartId,
+ OverlayBuildPartName,
+ NiceTimeSpanMs(GetOverlayBuildPartTimer.GetElapsedTimeMs()),
+ NiceBytes(OverlayBuildPartManifest.GetSize()));
}
ChunkedFolderContent OverlayPartContent;
@@ -7589,7 +7825,7 @@ namespace buildstorageoperations_testutils {
{
TestState(const std::filesystem::path& InRootPath)
: RootPath(InRootPath)
- , LogOutput(CreateStandardLogOutput(Log))
+ , LogOutput(CreateStandardProgress(Log))
, ChunkController(CreateStandardChunkingController(StandardChunkingControllerSettings{}))
, ChunkCache(CreateMemoryChunkingCache())
, WorkerPool(2)
@@ -7631,7 +7867,8 @@ namespace buildstorageoperations_testutils {
{
const std::filesystem::path SourcePath = RootPath / Source;
CbObject MetaData;
- BuildsOperationUploadFolder Upload(*LogOutput,
+ BuildsOperationUploadFolder Upload(Log,
+ *LogOutput,
Storage,
AbortFlag,
PauseFlag,
@@ -7649,7 +7886,8 @@ namespace buildstorageoperations_testutils {
{
for (auto Part : Parts)
{
- BuildsOperationValidateBuildPart Validate(*LogOutput,
+ BuildsOperationValidateBuildPart Validate(Log,
+ *LogOutput,
*Storage.BuildStorage,
AbortFlag,
PauseFlag,
@@ -7693,7 +7931,7 @@ namespace buildstorageoperations_testutils {
std::vector<ChunkBlockDescription> BlockDescriptions;
std::vector<IoHash> LooseChunkHashes;
- ChunkedFolderContent RemoteContent = GetRemoteContent(*LogOutput,
+ ChunkedFolderContent RemoteContent = GetRemoteContent(Log,
Storage,
BuildId,
AllBuildParts,
@@ -7774,7 +8012,8 @@ namespace buildstorageoperations_testutils {
const ChunkedContentLookup LocalLookup = BuildChunkedContentLookup(LocalContent);
const ChunkedContentLookup RemoteLookup = BuildChunkedContentLookup(RemoteContent);
- BuildsOperationUpdateFolder Download(*LogOutput,
+ BuildsOperationUpdateFolder Download(Log,
+ *LogOutput,
Storage,
AbortFlag,
PauseFlag,
@@ -7837,8 +8076,8 @@ namespace buildstorageoperations_testutils {
std::filesystem::path SystemRootDir;
std::filesystem::path ZenFolderPath;
- LoggerRef Log = ConsoleLog();
- std::unique_ptr<OperationLogOutput> LogOutput;
+ LoggerRef Log = ConsoleLog();
+ std::unique_ptr<ProgressBase> LogOutput;
std::unique_ptr<ChunkingController> ChunkController;
std::unique_ptr<ChunkingCache> ChunkCache;
@@ -7990,7 +8229,8 @@ TEST_CASE("buildstorageoperations.memorychunkingcache")
{
const std::filesystem::path SourcePath = SourceFolder.Path() / "source";
CbObject MetaData;
- BuildsOperationUploadFolder Upload(*State.LogOutput,
+ BuildsOperationUploadFolder Upload(State.Log,
+ *State.LogOutput,
State.Storage,
State.AbortFlag,
State.PauseFlag,
@@ -8020,7 +8260,8 @@ TEST_CASE("buildstorageoperations.memorychunkingcache")
{
const std::filesystem::path SourcePath = SourceFolder.Path() / "source";
CbObject MetaData;
- BuildsOperationUploadFolder Upload(*State.LogOutput,
+ BuildsOperationUploadFolder Upload(State.Log,
+ *State.LogOutput,
State.Storage,
State.AbortFlag,
State.PauseFlag,
diff --git a/src/zenremotestore/builds/buildstorageutil.cpp b/src/zenremotestore/builds/buildstorageutil.cpp
index 2ae726e29..144964e37 100644
--- a/src/zenremotestore/builds/buildstorageutil.cpp
+++ b/src/zenremotestore/builds/buildstorageutil.cpp
@@ -9,7 +9,6 @@
#include <zenremotestore/builds/jupiterbuildstorage.h>
#include <zenremotestore/chunking/chunkblock.h>
#include <zenremotestore/jupiter/jupiterhost.h>
-#include <zenremotestore/operationlogoutput.h>
#include <zenutil/zenserverprocess.h>
namespace zen {
@@ -32,7 +31,7 @@ namespace {
} // namespace
BuildStorageResolveResult
-ResolveBuildStorage(OperationLogOutput& Output,
+ResolveBuildStorage(LoggerRef InLog,
const HttpClientSettings& ClientSettings,
std::string_view Host,
std::string_view OverrideHost,
@@ -40,6 +39,8 @@ ResolveBuildStorage(OperationLogOutput& Output,
ZenCacheResolveMode ZenResolveMode,
bool Verbose)
{
+ ZEN_SCOPED_LOG(InLog);
+
bool AllowZenCacheDiscovery = ZenResolveMode == ZenCacheResolveMode::Discovery || ZenResolveMode == ZenCacheResolveMode::All;
bool AllowLocalZenCache = ZenResolveMode == ZenCacheResolveMode::LocalHost || ZenResolveMode == ZenCacheResolveMode::All;
@@ -80,10 +81,9 @@ ResolveBuildStorage(OperationLogOutput& Output,
{
if (Verbose)
{
- ZEN_OPERATION_LOG_INFO(Output,
- "Querying servers at '{}/api/v1/status/servers'\n Connection settings:{}",
- DiscoveryHost,
- ConnectionSettingsToString(ClientSettings));
+ ZEN_INFO("Querying servers at '{}/api/v1/status/servers'\n Connection settings:{}",
+ DiscoveryHost,
+ ConnectionSettingsToString(ClientSettings));
}
DiscoveryResponse = DiscoverJupiterEndpoints(DiscoveryHost, ClientSettings);
@@ -93,14 +93,14 @@ ResolveBuildStorage(OperationLogOutput& Output,
{
if (Verbose)
{
- ZEN_OPERATION_LOG_INFO(Output, "Testing server endpoint at '{}/health/live'. Assume http2: {}", OverrideHost, HostAssumeHttp2);
+ ZEN_INFO("Testing server endpoint at '{}/health/live'. Assume http2: {}", OverrideHost, HostAssumeHttp2);
}
if (JupiterEndpointTestResult TestResult = TestJupiterEndpoint(OverrideHost, HostAssumeHttp2, ClientSettings.Verbose);
TestResult.Success)
{
if (Verbose)
{
- ZEN_OPERATION_LOG_INFO(Output, "Server endpoint at '{}/api/v1/status/servers' succeeded", OverrideHost);
+ ZEN_INFO("Server endpoint at '{}/api/v1/status/servers' succeeded", OverrideHost);
}
HostUrl = OverrideHost;
HostName = GetHostNameFromUrl(OverrideHost);
@@ -125,10 +125,9 @@ ResolveBuildStorage(OperationLogOutput& Output,
{
if (Verbose)
{
- ZEN_OPERATION_LOG_INFO(Output,
- "Testing server endpoint at '{}/health/live'. Assume http2: {}",
- ServerEndpoint.BaseUrl,
- ServerEndpoint.AssumeHttp2);
+ ZEN_INFO("Testing server endpoint at '{}/health/live'. Assume http2: {}",
+ ServerEndpoint.BaseUrl,
+ ServerEndpoint.AssumeHttp2);
}
if (JupiterEndpointTestResult TestResult =
@@ -137,7 +136,7 @@ ResolveBuildStorage(OperationLogOutput& Output,
{
if (Verbose)
{
- ZEN_OPERATION_LOG_INFO(Output, "Server endpoint at '{}/api/v1/status/servers' succeeded", ServerEndpoint.BaseUrl);
+ ZEN_INFO("Server endpoint at '{}/api/v1/status/servers' succeeded", ServerEndpoint.BaseUrl);
}
HostUrl = ServerEndpoint.BaseUrl;
@@ -149,10 +148,7 @@ ResolveBuildStorage(OperationLogOutput& Output,
}
else
{
- ZEN_OPERATION_LOG_DEBUG(Output,
- "Unable to reach host {}. Reason: {}",
- ServerEndpoint.BaseUrl,
- TestResult.FailureReason);
+ ZEN_DEBUG("Unable to reach host {}. Reason: {}", ServerEndpoint.BaseUrl, TestResult.FailureReason);
}
}
}
@@ -173,10 +169,9 @@ ResolveBuildStorage(OperationLogOutput& Output,
{
if (Verbose)
{
- ZEN_OPERATION_LOG_INFO(Output,
- "Testing cache endpoint at '{}/status/builds'. Assume http2: {}",
- CacheEndpoint.BaseUrl,
- CacheEndpoint.AssumeHttp2);
+ ZEN_INFO("Testing cache endpoint at '{}/status/builds'. Assume http2: {}",
+ CacheEndpoint.BaseUrl,
+ CacheEndpoint.AssumeHttp2);
}
if (ZenCacheEndpointTestResult TestResult =
@@ -185,7 +180,7 @@ ResolveBuildStorage(OperationLogOutput& Output,
{
if (Verbose)
{
- ZEN_OPERATION_LOG_INFO(Output, "Cache endpoint at '{}/status/builds' succeeded", CacheEndpoint.BaseUrl);
+ ZEN_INFO("Cache endpoint at '{}/status/builds' succeeded", CacheEndpoint.BaseUrl);
}
CacheUrl = CacheEndpoint.BaseUrl;
@@ -225,7 +220,7 @@ ResolveBuildStorage(OperationLogOutput& Output,
{
if (Verbose)
{
- ZEN_OPERATION_LOG_INFO(Output, "Testing cache endpoint at '{}/status/builds'. Assume http2: {}", ZenCacheHost, false);
+ ZEN_INFO("Testing cache endpoint at '{}/status/builds'. Assume http2: {}", ZenCacheHost, false);
}
if (ZenCacheEndpointTestResult TestResult = TestZenCacheEndpoint(ZenCacheHost, /*AssumeHttp2*/ false, ClientSettings.Verbose);
TestResult.Success)
@@ -272,7 +267,7 @@ ParseBlockMetadatas(std::span<const CbObject> BlockMetadatas)
}
std::vector<ChunkBlockDescription>
-GetBlockDescriptions(OperationLogOutput& Output,
+GetBlockDescriptions(LoggerRef InLog,
BuildStorageBase& Storage,
BuildStorageCache* OptionalCacheStorage,
const Oid& BuildId,
@@ -282,6 +277,7 @@ GetBlockDescriptions(OperationLogOutput& Output,
bool IsVerbose)
{
using namespace std::literals;
+ ZEN_SCOPED_LOG(InLog);
std::vector<ChunkBlockDescription> UnorderedList;
tsl::robin_map<IoHash, size_t, IoHash::Hasher> BlockDescriptionLookup;
@@ -322,7 +318,7 @@ GetBlockDescriptions(OperationLogOutput& Output,
if (Description.BlockHash == IoHash::Zero)
{
- ZEN_OPERATION_LOG_WARN(Output, "Unexpected/invalid block metadata received from remote store, skipping block");
+ ZEN_WARN("Unexpected/invalid block metadata received from remote store, skipping block");
}
else
{
@@ -383,7 +379,7 @@ GetBlockDescriptions(OperationLogOutput& Output,
}
if (AttemptFallback)
{
- ZEN_OPERATION_LOG_WARN(Output, "{} Attemping fallback options.", ErrorDescription);
+ ZEN_WARN("{} Attemping fallback options.", ErrorDescription);
std::vector<ChunkBlockDescription> AugmentedBlockDescriptions;
AugmentedBlockDescriptions.reserve(BlockRawHashes.size());
std::vector<ChunkBlockDescription> FoundBlocks = ParseChunkBlockDescriptionList(Storage.FindBlocks(BuildId, (uint64_t)-1));
@@ -406,7 +402,7 @@ GetBlockDescriptions(OperationLogOutput& Output,
{
if (!IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(Output, "Found block {} via context find successfully", BlockHash);
+ ZEN_INFO("Found block {} via context find successfully", BlockHash);
}
AugmentedBlockDescriptions.emplace_back(std::move(*ListBlocksIt));
}
diff --git a/src/zenremotestore/builds/jupiterbuildstorage.cpp b/src/zenremotestore/builds/jupiterbuildstorage.cpp
index ad4c4bc89..d837ce07f 100644
--- a/src/zenremotestore/builds/jupiterbuildstorage.cpp
+++ b/src/zenremotestore/builds/jupiterbuildstorage.cpp
@@ -263,7 +263,7 @@ public:
std::vector<std::function<void()>> WorkList;
for (auto& WorkItem : WorkItems)
{
- WorkList.emplace_back([this, WorkItem = std::move(WorkItem), OnSentBytes]() {
+ WorkList.emplace_back([this, WorkItem = std::move(WorkItem), OnSentBytes = std::move(OnSentBytes)]() {
Stopwatch ExecutionTimer;
auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); });
bool IsComplete = false;
@@ -444,11 +444,13 @@ public:
virtual bool GetExtendedStatistics(ExtendedStatistics& OutStats) override
{
- OutStats.ReceivedBytesPerSource.reserve(m_ReceivedBytesPerSource.size());
- for (auto& It : m_ReceivedBytesPerSource)
- {
- OutStats.ReceivedBytesPerSource.insert_or_assign(It.first, m_SourceBytes[It.second]);
- }
+ m_SourceLock.WithSharedLock([this, &OutStats]() {
+ OutStats.ReceivedBytesPerSource.reserve(m_ReceivedBytesPerSource.size());
+ for (auto& It : m_ReceivedBytesPerSource)
+ {
+ OutStats.ReceivedBytesPerSource.insert_or_assign(It.first, m_SourceBytes[It.second].load(std::memory_order_relaxed));
+ }
+ });
return true;
}
@@ -521,15 +523,29 @@ private:
}
if (!Result.Source.empty())
{
- if (tsl::robin_map<std::string, uint32_t>::const_iterator It = m_ReceivedBytesPerSource.find(Result.Source);
- It != m_ReceivedBytesPerSource.end())
- {
- m_SourceBytes[It->second] += Result.ReceivedBytes;
- }
- else
+ if (!m_SourceLock.WithSharedLock([&]() {
+ if (tsl::robin_map<std::string, uint32_t>::const_iterator It = m_ReceivedBytesPerSource.find(Result.Source);
+ It != m_ReceivedBytesPerSource.end())
+ {
+ m_SourceBytes[It->second] += Result.ReceivedBytes;
+ return true;
+ }
+ return false;
+ }))
{
- m_ReceivedBytesPerSource.insert_or_assign(Result.Source, m_SourceBytes.size());
- m_SourceBytes.push_back(Result.ReceivedBytes);
+ m_SourceLock.WithExclusiveLock([&]() {
+ if (tsl::robin_map<std::string, uint32_t>::const_iterator It = m_ReceivedBytesPerSource.find(Result.Source);
+ It != m_ReceivedBytesPerSource.end())
+ {
+ m_SourceBytes[It->second] += Result.ReceivedBytes;
+ }
+ else if (m_SourceCount < MaxSourceCount)
+ {
+ size_t Index = m_SourceCount++;
+ m_ReceivedBytesPerSource.insert_or_assign(Result.Source, Index);
+ m_SourceBytes[Index] += Result.ReceivedBytes;
+ }
+ });
}
}
}
@@ -540,8 +556,11 @@ private:
const std::string m_Bucket;
const std::filesystem::path m_TempFolderPath;
- tsl::robin_map<std::string, uint32_t> m_ReceivedBytesPerSource;
- std::vector<uint64_t> m_SourceBytes;
+ RwLock m_SourceLock;
+ tsl::robin_map<std::string, uint32_t> m_ReceivedBytesPerSource;
+ static constexpr size_t MaxSourceCount = 8u;
+ std::array<std::atomic<uint64_t>, MaxSourceCount> m_SourceBytes;
+ size_t m_SourceCount = 0;
};
std::unique_ptr<BuildStorageBase>
diff --git a/src/zenremotestore/chunking/chunkblock.cpp b/src/zenremotestore/chunking/chunkblock.cpp
index 0fe3c09ce..f29112f53 100644
--- a/src/zenremotestore/chunking/chunkblock.cpp
+++ b/src/zenremotestore/chunking/chunkblock.cpp
@@ -7,7 +7,6 @@
#include <zencore/logging.h>
#include <zencore/timer.h>
#include <zencore/trace.h>
-#include <zenremotestore/operationlogoutput.h>
#include <numeric>
@@ -445,7 +444,7 @@ IterateChunkBlock(const SharedBuffer& BlockPayload,
};
std::vector<size_t>
-FindReuseBlocks(OperationLogOutput& Output,
+FindReuseBlocks(LoggerRef InLog,
const uint8_t BlockReuseMinPercentLimit,
const bool IsVerbose,
ReuseBlocksStatistics& Stats,
@@ -455,6 +454,7 @@ FindReuseBlocks(OperationLogOutput& Output,
std::vector<uint32_t>& OutUnusedChunkIndexes)
{
ZEN_TRACE_CPU("FindReuseBlocks");
+ ZEN_SCOPED_LOG(InLog);
// Find all blocks with a usage level higher than MinPercentLimit
// Pick out the blocks with usage higher or equal to MinPercentLimit
@@ -521,11 +521,10 @@ FindReuseBlocks(OperationLogOutput& Output,
{
if (IsVerbose)
{
- ZEN_OPERATION_LOG_INFO(Output,
- "Reusing block {}. {} attachments found, usage level: {}%",
- KnownBlock.BlockHash,
- FoundAttachmentCount,
- ReusePercent);
+ ZEN_INFO("Reusing block {}. {} attachments found, usage level: {}%",
+ KnownBlock.BlockHash,
+ FoundAttachmentCount,
+ ReusePercent);
}
ReuseBlockIndexes.push_back(KnownBlockIndex);
@@ -534,12 +533,13 @@ FindReuseBlocks(OperationLogOutput& Output,
}
else if (FoundAttachmentCount > 0)
{
- // if (IsVerbose)
- //{
- // ZEN_OPERATION_LOG_INFO(Output, "Skipping block {}. {} attachments found, usage level: {}%",
- // KnownBlock.BlockHash,
- // FoundAttachmentCount, ReusePercent);
- //}
+ if (IsVerbose)
+ {
+ ZEN_INFO("Skipping block {}. {} attachments found, usage level: {}%",
+ KnownBlock.BlockHash,
+ FoundAttachmentCount,
+ ReusePercent);
+ }
Stats.RejectedBlockCount++;
Stats.RejectedChunkCount += FoundAttachmentCount;
Stats.RejectedByteCount += ReuseSize;
@@ -583,11 +583,10 @@ FindReuseBlocks(OperationLogOutput& Output,
{
if (IsVerbose)
{
- ZEN_OPERATION_LOG_INFO(Output,
- "Reusing block {}. {} attachments found, usage level: {}%",
- KnownBlock.BlockHash,
- FoundChunkIndexes.size(),
- ReusePercent);
+ ZEN_INFO("Reusing block {}. {} attachments found, usage level: {}%",
+ KnownBlock.BlockHash,
+ FoundChunkIndexes.size(),
+ ReusePercent);
}
FilteredReuseBlockIndexes.push_back(KnownBlockIndex);
@@ -604,11 +603,10 @@ FindReuseBlocks(OperationLogOutput& Output,
}
else
{
- // if (IsVerbose)
- //{
- // ZEN_OPERATION_LOG_INFO(Output, "Skipping block {}. filtered usage level: {}%", KnownBlock.BlockHash,
- // ReusePercent);
- //}
+ if (IsVerbose)
+ {
+ ZEN_INFO("Skipping block {}. filtered usage level: {}%", KnownBlock.BlockHash, ReusePercent);
+ }
Stats.RejectedBlockCount++;
Stats.RejectedChunkCount += FoundChunkIndexes.size();
Stats.RejectedByteCount += AdjustedReuseSize;
@@ -629,10 +627,8 @@ FindReuseBlocks(OperationLogOutput& Output,
return FilteredReuseBlockIndexes;
}
-ChunkBlockAnalyser::ChunkBlockAnalyser(OperationLogOutput& LogOutput,
- std::span<const ChunkBlockDescription> BlockDescriptions,
- const Options& Options)
-: m_LogOutput(LogOutput)
+ChunkBlockAnalyser::ChunkBlockAnalyser(LoggerRef Log, std::span<const ChunkBlockDescription> BlockDescriptions, const Options& Options)
+: m_Log(Log)
, m_BlockDescriptions(BlockDescriptions)
, m_Options(Options)
{
@@ -899,20 +895,20 @@ ChunkBlockAnalyser::CalculatePartialBlockDownloads(std::span<const NeededBlock>
if (Ideal < FullDownloadTotalSize && !m_Options.IsQuiet)
{
const double AchievedPercent = Ideal == 0 ? 100.0 : (100.0 * Actual) / Ideal;
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Block Partial Analysis: Blocks: {}, Full: {}, Ideal: {}, Actual: {}. Skipping {} ({:.1f}%) out of "
- "possible {} using {} extra ranges "
- "via {} extra requests. Completed in {}",
- NeededBlocks.size(),
- NiceBytes(FullDownloadTotalSize),
- NiceBytes(IdealDownloadTotalSize),
- NiceBytes(ActualDownloadTotalSize),
- NiceBytes(FullDownloadTotalSize - ActualDownloadTotalSize),
- AchievedPercent,
- NiceBytes(Ideal),
- RangeCount - MinRequestCount,
- RequestCount - MinRequestCount,
- NiceTimeSpanMs(PartialAnalisysTimer.GetElapsedTimeMs()));
+ ZEN_INFO(
+ "Block Partial Analysis: Blocks: {}, Full: {}, Ideal: {}, Actual: {}. Skipping {} ({:.1f}%) out of "
+ "possible {} using {} extra ranges "
+ "via {} extra requests. Completed in {}",
+ NeededBlocks.size(),
+ NiceBytes(FullDownloadTotalSize),
+ NiceBytes(IdealDownloadTotalSize),
+ NiceBytes(ActualDownloadTotalSize),
+ NiceBytes(FullDownloadTotalSize - ActualDownloadTotalSize),
+ AchievedPercent,
+ NiceBytes(Ideal),
+ RangeCount - MinRequestCount,
+ RequestCount - MinRequestCount,
+ NiceTimeSpanMs(PartialAnalisysTimer.GetElapsedTimeMs()));
}
}
@@ -1001,8 +997,7 @@ TEST_CASE("chunkblock.reuseblocks")
BlockDescriptions.emplace_back(std::move(Block));
}
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
{
// We use just about all the chunks - should result in use of both blocks
@@ -1019,14 +1014,8 @@ TEST_CASE("chunkblock.reuseblocks")
std::iota(ManyChunkIndexes.begin(), ManyChunkIndexes.end(), 0);
std::vector<uint32_t> UnusedChunkIndexes;
- std::vector<size_t> ReusedBlocks = FindReuseBlocks(*LogOutput,
- 80,
- false,
- ReuseBlocksStats,
- BlockDescriptions,
- ManyChunkHashes,
- ManyChunkIndexes,
- UnusedChunkIndexes);
+ std::vector<size_t> ReusedBlocks =
+ FindReuseBlocks(LogRef, 80, false, ReuseBlocksStats, BlockDescriptions, ManyChunkHashes, ManyChunkIndexes, UnusedChunkIndexes);
CHECK_EQ(2u, ReusedBlocks.size());
CHECK_EQ(0u, UnusedChunkIndexes.size());
@@ -1047,7 +1036,7 @@ TEST_CASE("chunkblock.reuseblocks")
std::iota(ManyChunkIndexes.begin(), ManyChunkIndexes.end(), 0);
std::vector<uint32_t> UnusedChunkIndexes;
- std::vector<size_t> ReusedBlocks = FindReuseBlocks(*LogOutput,
+ std::vector<size_t> ReusedBlocks = FindReuseBlocks(LogRef,
80,
false,
ReuseBlocksStats,
@@ -1076,7 +1065,7 @@ TEST_CASE("chunkblock.reuseblocks")
// We use half the chunks - should result in no use of blocks due to 80% limit
std::vector<uint32_t> UnusedChunkIndexes80Percent;
ReuseBlocksStatistics ReuseBlocksStats;
- std::vector<size_t> ReusedBlocks80Percent = FindReuseBlocks(*LogOutput,
+ std::vector<size_t> ReusedBlocks80Percent = FindReuseBlocks(LogRef,
80,
false,
ReuseBlocksStats,
@@ -1092,7 +1081,7 @@ TEST_CASE("chunkblock.reuseblocks")
// We use half the chunks - should result in use of both blocks due to 40% limit
std::vector<uint32_t> UnusedChunkIndexes40Percent;
ReuseBlocksStatistics ReuseBlocksStats;
- std::vector<size_t> ReusedBlocks40Percent = FindReuseBlocks(*LogOutput,
+ std::vector<size_t> ReusedBlocks40Percent = FindReuseBlocks(LogRef,
40,
false,
ReuseBlocksStats,
@@ -1122,7 +1111,7 @@ TEST_CASE("chunkblock.reuseblocks")
// We use half the chunks for first block - should result in use of one blocks due to 80% limit
ReuseBlocksStatistics ReuseBlocksStats;
std::vector<uint32_t> UnusedChunkIndexes80Percent;
- std::vector<size_t> ReusedBlocks80Percent = FindReuseBlocks(*LogOutput,
+ std::vector<size_t> ReusedBlocks80Percent = FindReuseBlocks(LogRef,
80,
false,
ReuseBlocksStats,
@@ -1139,7 +1128,7 @@ TEST_CASE("chunkblock.reuseblocks")
// We use half the chunks - should result in use of both blocks due to 40% limit
ReuseBlocksStatistics ReuseBlocksStats;
std::vector<uint32_t> UnusedChunkIndexes40Percent;
- std::vector<size_t> ReusedBlocks40Percent = FindReuseBlocks(*LogOutput,
+ std::vector<size_t> ReusedBlocks40Percent = FindReuseBlocks(LogRef,
40,
false,
ReuseBlocksStats,
@@ -1178,7 +1167,7 @@ TEST_CASE("chunkblock.reuseblocks")
// We use half the chunks for first block - should result in use of one blocks due to 80% limit
ReuseBlocksStatistics ReuseBlocksStats;
std::vector<uint32_t> UnusedChunkIndexes80Percent;
- std::vector<size_t> ReusedBlocks80Percent = FindReuseBlocks(*LogOutput,
+ std::vector<size_t> ReusedBlocks80Percent = FindReuseBlocks(LogRef,
80,
false,
ReuseBlocksStats,
@@ -1195,7 +1184,7 @@ TEST_CASE("chunkblock.reuseblocks")
// We use half the chunks - should result in use of both blocks due to 40% limit
ReuseBlocksStatistics ReuseBlocksStats;
std::vector<uint32_t> UnusedChunkIndexes40Percent;
- std::vector<size_t> ReusedBlocks40Percent = FindReuseBlocks(*LogOutput,
+ std::vector<size_t> ReusedBlocks40Percent = FindReuseBlocks(LogRef,
40,
false,
ReuseBlocksStats,
@@ -1214,7 +1203,7 @@ namespace chunkblock_analyser_testutils {
// Build a ChunkBlockDescription without any real payload.
// Hashes are derived deterministically from (BlockSeed XOR ChunkIndex) so that the same
- // seed produces the same hashes — useful for deduplication tests.
+ // seed produces the same hashes - useful for deduplication tests.
static ChunkBlockDescription MakeBlockDesc(uint64_t HeaderSize,
std::initializer_list<uint32_t> CompressedLengths,
uint32_t BlockSeed = 0)
@@ -1257,7 +1246,7 @@ namespace chunkblock_analyser_testutils {
TEST_CASE("chunkblock.mergecheapestrange.picks_smallest_gap")
{
using RD = chunkblock_impl::RangeDescriptor;
- // Gap between ranges 0-1 is 50, gap between 1-2 is 150 → pair 0-1 gets merged
+ // Gap between ranges 0-1 is 50, gap between 1-2 is 150 -> pair 0-1 gets merged
std::vector<RD> Ranges = {
{.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1},
{.RangeStart = 150, .RangeLength = 100, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1},
@@ -1279,7 +1268,7 @@ TEST_CASE("chunkblock.mergecheapestrange.picks_smallest_gap")
TEST_CASE("chunkblock.mergecheapestrange.tiebreak_smaller_merged")
{
using RD = chunkblock_impl::RangeDescriptor;
- // Gap 0-1 == gap 1-2 == 100; merged size 0-1 (250) < merged size 1-2 (350) → pair 0-1 wins
+ // Gap 0-1 == gap 1-2 == 100; merged size 0-1 (250) < merged size 1-2 (350) -> pair 0-1 wins
std::vector<RD> Ranges = {
{.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1},
{.RangeStart = 200, .RangeLength = 50, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1},
@@ -1304,7 +1293,7 @@ TEST_CASE("chunkblock.optimizeranges.preserves_ranges_low_latency")
{
using RD = chunkblock_impl::RangeDescriptor;
// With MaxRangeCountPerRequest unlimited, RequestCount=1
- // RequestTimeAsBytes = 100000 * 1 * 0.001 = 100 << slack=7000 → all ranges preserved
+ // RequestTimeAsBytes = 100000 * 1 * 0.001 = 100 << slack=7000 -> all ranges preserved
std::vector<RD> ExactRanges = {
{.RangeStart = 0, .RangeLength = 1000, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1},
{.RangeStart = 2000, .RangeLength = 1000, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1},
@@ -1325,7 +1314,7 @@ TEST_CASE("chunkblock.optimizeranges.preserves_ranges_low_latency")
TEST_CASE("chunkblock.optimizeranges.falls_back_to_full_block")
{
using RD = chunkblock_impl::RangeDescriptor;
- // 1 range already; slack=100 < SpeedBytesPerSec*LatencySec=200 → full block (empty result)
+ // 1 range already; slack=100 < SpeedBytesPerSec*LatencySec=200 -> full block (empty result)
std::vector<RD> ExactRanges = {
{.RangeStart = 100, .RangeLength = 900, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 3},
};
@@ -1344,7 +1333,7 @@ TEST_CASE("chunkblock.optimizeranges.falls_back_to_full_block")
TEST_CASE("chunkblock.optimizeranges.maxrangesperblock_clamp")
{
using RD = chunkblock_impl::RangeDescriptor;
- // 5 input ranges; MaxRangesPerBlock=2 clamps to ≤2 before the cost model runs
+ // 5 input ranges; MaxRangesPerBlock=2 clamps to <=2 before the cost model runs
std::vector<RD> ExactRanges = {
{.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1},
{.RangeStart = 300, .RangeLength = 100, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1},
@@ -1378,8 +1367,8 @@ TEST_CASE("chunkblock.optimizeranges.low_maxrangecountperrequest_drives_merge")
uint64_t TotalBlockSize = 1000;
double LatencySec = 1.0;
uint64_t SpeedBytesPerSec = 500;
- // With MaxRangeCountPerRequest=-1: RequestCount=1, RequestTimeAsBytes=500 < slack=700 → preserved
- // With MaxRangeCountPerRequest=1: RequestCount=3, RequestTimeAsBytes=1500 > slack=700 → merged
+ // With MaxRangeCountPerRequest=-1: RequestCount=1, RequestTimeAsBytes=500 < slack=700 -> preserved
+ // With MaxRangeCountPerRequest=1: RequestCount=3, RequestTimeAsBytes=1500 > slack=700 -> merged
uint64_t MaxRangesPerBlock = 1024;
auto Unlimited =
@@ -1394,7 +1383,7 @@ TEST_CASE("chunkblock.optimizeranges.low_maxrangecountperrequest_drives_merge")
TEST_CASE("chunkblock.optimizeranges.unlimited_rangecountperrequest_no_extra_cost")
{
using RD = chunkblock_impl::RangeDescriptor;
- // MaxRangeCountPerRequest=-1 → RequestCount always 1, even with many ranges and high latency
+ // MaxRangeCountPerRequest=-1 -> RequestCount always 1, even with many ranges and high latency
std::vector<RD> ExactRanges = {
{.RangeStart = 0, .RangeLength = 50, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 1},
{.RangeStart = 200, .RangeLength = 50, .ChunkBlockIndexStart = 1, .ChunkBlockIndexCount = 1},
@@ -1418,7 +1407,7 @@ TEST_CASE("chunkblock.optimizeranges.two_range_direct_merge_path")
{
using RD = chunkblock_impl::RangeDescriptor;
// Exactly 2 ranges; cost model demands merge; exercises the RangeCount==2 direct-merge branch
- // After direct merge → 1 range with small slack → full block (empty)
+ // After direct merge -> 1 range with small slack -> full block (empty)
std::vector<RD> ExactRanges = {
{.RangeStart = 0, .RangeLength = 100, .ChunkBlockIndexStart = 0, .ChunkBlockIndexCount = 2},
{.RangeStart = 400, .RangeLength = 100, .ChunkBlockIndexStart = 2, .ChunkBlockIndexCount = 2},
@@ -1429,8 +1418,8 @@ TEST_CASE("chunkblock.optimizeranges.two_range_direct_merge_path")
uint64_t MaxRangeCountPerReq = (uint64_t)-1;
uint64_t MaxRangesPerBlock = 1024;
- // Iteration 1: RangeCount=2, RequestCount=1, RequestTimeAsBytes=500 > slack=400 → direct merge
- // After merge: 1 range [{0,500,0,4}], slack=100 < Speed*Lat=500 → full block
+ // Iteration 1: RangeCount=2, RequestCount=1, RequestTimeAsBytes=500 > slack=400 -> direct merge
+ // After merge: 1 range [{0,500,0,4}], slack=100 < Speed*Lat=500 -> full block
auto Result =
chunkblock_impl::OptimizeRanges(TotalBlockSize, ExactRanges, LatencySec, SpeedBytesPerSec, MaxRangeCountPerReq, MaxRangesPerBlock);
@@ -1441,12 +1430,11 @@ TEST_CASE("chunkblock.getneeded.all_chunks")
{
using namespace chunkblock_analyser_testutils;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
auto Block = MakeBlockDesc(50, {100, 100, 100, 100});
ChunkBlockAnalyser::Options Options;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
auto HashMap = MakeHashMap({Block});
auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t) { return true; });
@@ -1464,12 +1452,11 @@ TEST_CASE("chunkblock.getneeded.no_chunks")
{
using namespace chunkblock_analyser_testutils;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
auto Block = MakeBlockDesc(50, {100, 100, 100, 100});
ChunkBlockAnalyser::Options Options;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
auto HashMap = MakeHashMap({Block});
auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t) { return false; });
@@ -1481,12 +1468,11 @@ TEST_CASE("chunkblock.getneeded.subset_within_block")
{
using namespace chunkblock_analyser_testutils;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
auto Block = MakeBlockDesc(50, {100, 100, 100, 100});
ChunkBlockAnalyser::Options Options;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
auto HashMap = MakeHashMap({Block});
// Indices 0 and 2 are needed; 1 and 3 are not
@@ -1503,12 +1489,11 @@ TEST_CASE("chunkblock.getneeded.dedup_low_slack_wins")
{
using namespace chunkblock_analyser_testutils;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
- // Block 0: {H0, H1, SharedH, H3} — 3 of 4 needed (H3 not needed); slack = 100
- // Block 1: {H4, H5, SharedH, H6} — only SharedH needed; slack = 300
- // Block 0 has less slack → processed first → SharedH assigned to block 0
+ // Block 0: {H0, H1, SharedH, H3} - 3 of 4 needed (H3 not needed); slack = 100
+ // Block 1: {H4, H5, SharedH, H6} - only SharedH needed; slack = 300
+ // Block 0 has less slack -> processed first -> SharedH assigned to block 0
IoHash SharedH = IoHash::HashBuffer(MemoryView("shared_chunk_dedup", 18));
IoHash H0 = IoHash::HashBuffer(MemoryView("block0_chunk0", 13));
IoHash H1 = IoHash::HashBuffer(MemoryView("block0_chunk1", 13));
@@ -1531,9 +1516,9 @@ TEST_CASE("chunkblock.getneeded.dedup_low_slack_wins")
std::vector<ChunkBlockDescription> Blocks = {Block0, Block1};
ChunkBlockAnalyser::Options Options;
- ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options);
+ ChunkBlockAnalyser Analyser(LogRef, Blocks, Options);
- // Map: H0→0, H1→1, SharedH→2, H3→3, H4→4, H5→5, H6→6
+ // Map: H0->0, H1->1, SharedH->2, H3->3, H4->4, H5->5, H6->6
auto HashMap = MakeHashMap(Blocks);
// Need H0(0), H1(1), SharedH(2) from block 0; SharedH from block 1 (already index 2)
// H3(3) not needed; H4,H5,H6 not needed
@@ -1541,7 +1526,7 @@ TEST_CASE("chunkblock.getneeded.dedup_low_slack_wins")
// Block 0 slack=100 (H3 unused), block 1 slack=300 (H4,H5,H6 unused)
// Block 0 processed first; picks up H0, H1, SharedH
- // Block 1 tries SharedH but it's already picked up → empty → not added
+ // Block 1 tries SharedH but it's already picked up -> empty -> not added
REQUIRE_EQ(1u, NeededBlocks.size());
CHECK_EQ(0u, NeededBlocks[0].BlockIndex);
REQUIRE_EQ(3u, NeededBlocks[0].ChunkIndexes.size());
@@ -1554,8 +1539,7 @@ TEST_CASE("chunkblock.getneeded.dedup_no_double_pickup")
{
using namespace chunkblock_analyser_testutils;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
// SharedH appears in both blocks; should appear in the result exactly once
IoHash SharedH = IoHash::HashBuffer(MemoryView("shared_chunk_nodup", 18));
@@ -1578,16 +1562,16 @@ TEST_CASE("chunkblock.getneeded.dedup_no_double_pickup")
std::vector<ChunkBlockDescription> Blocks = {Block0, Block1};
ChunkBlockAnalyser::Options Options;
- ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options);
+ ChunkBlockAnalyser Analyser(LogRef, Blocks, Options);
- // Map: SharedH→0, H0→1, H1→2, H2→3, H3→4
+ // Map: SharedH->0, H0->1, H1->2, H2->3, H3->4
// Only SharedH (index 0) needed; no other chunks
auto HashMap = MakeHashMap(Blocks);
auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t ChunkIndex) { return ChunkIndex == 0; });
- // Block 0: SharedH needed, H0 not needed → slack=100
- // Block 1: SharedH needed, H1/H2/H3 not needed → slack=300
- // Block 0 processed first → picks up SharedH; Block 1 skips it
+ // Block 0: SharedH needed, H0 not needed -> slack=100
+ // Block 1: SharedH needed, H1/H2/H3 not needed -> slack=300
+ // Block 0 processed first -> picks up SharedH; Block 1 skips it
// Count total occurrences of SharedH across all NeededBlocks
uint32_t SharedOccurrences = 0;
@@ -1609,13 +1593,12 @@ TEST_CASE("chunkblock.getneeded.skips_unrequested_chunks")
{
using namespace chunkblock_analyser_testutils;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
- // Block has 4 chunks but only 2 appear in the hash map → ChunkIndexes has exactly those 2
+ // Block has 4 chunks but only 2 appear in the hash map -> ChunkIndexes has exactly those 2
auto Block = MakeBlockDesc(50, {100, 100, 100, 100});
ChunkBlockAnalyser::Options Options;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
// Only put chunks at positions 0 and 2 in the map
tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> HashMap;
@@ -1635,19 +1618,18 @@ TEST_CASE("chunkblock.getneeded.two_blocks_both_contribute")
{
using namespace chunkblock_analyser_testutils;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
// Block 0: all 4 needed (slack=0); block 1: 3 of 4 needed (slack=100)
- // Both blocks contribute chunks → 2 NeededBlocks in result
+ // Both blocks contribute chunks -> 2 NeededBlocks in result
auto Block0 = MakeBlockDesc(50, {100, 100, 100, 100}, /*BlockSeed=*/0);
auto Block1 = MakeBlockDesc(50, {100, 100, 100, 100}, /*BlockSeed=*/200);
std::vector<ChunkBlockDescription> Blocks = {Block0, Block1};
ChunkBlockAnalyser::Options Options;
- ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options);
+ ChunkBlockAnalyser Analyser(LogRef, Blocks, Options);
- // HashMap: Block0 hashes → indices 0-3, Block1 hashes → indices 4-7
+ // HashMap: Block0 hashes -> indices 0-3, Block1 hashes -> indices 4-7
auto HashMap = MakeHashMap(Blocks);
// Need all Block0 chunks (0-3) and Block1 chunks 0-2 (indices 4-6); not chunk index 7 (Block1 chunk 3)
auto NeededBlocks = Analyser.GetNeeded(HashMap, [](uint32_t ChunkIndex) { return ChunkIndex <= 6; });
@@ -1666,15 +1648,14 @@ TEST_CASE("chunkblock.calc.off_mode")
using namespace chunkblock_analyser_testutils;
using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
- // HeaderSize > 0, chunks size matches → CanDoPartialBlockDownload = true
+ // HeaderSize > 0, chunks size matches -> CanDoPartialBlockDownload = true
// But mode Off forces full block regardless
auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
ChunkBlockAnalyser::Options Options;
Options.IsQuiet = true;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}};
std::vector<Mode> Modes = {Mode::Off};
@@ -1691,17 +1672,16 @@ TEST_CASE("chunkblock.calc.exact_mode")
using namespace chunkblock_analyser_testutils;
using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
ChunkBlockAnalyser::Options Options;
Options.IsQuiet = true;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
- // Need chunks 0 and 2 → 2 non-contiguous ranges; Exact mode passes them straight through
+ // Need chunks 0 and 2 -> 2 non-contiguous ranges; Exact mode passes them straight through
std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}};
std::vector<Mode> Modes = {Mode::Exact};
@@ -1728,18 +1708,17 @@ TEST_CASE("chunkblock.calc.singlerange_mode")
using namespace chunkblock_analyser_testutils;
using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
- // Default HostLatencySec=-1 → OptimizeRanges not called after SingleRange collapse
+ // Default HostLatencySec=-1 -> OptimizeRanges not called after SingleRange collapse
ChunkBlockAnalyser::Options Options;
Options.IsQuiet = true;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
- // Need chunks 0 and 2 → 2 ranges that get collapsed to 1
+ // Need chunks 0 and 2 -> 2 ranges that get collapsed to 1
std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}};
std::vector<Mode> Modes = {Mode::SingleRange};
@@ -1761,16 +1740,15 @@ TEST_CASE("chunkblock.calc.multirange_mode")
using namespace chunkblock_analyser_testutils;
using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
- // Low latency: RequestTimeAsBytes=100 << slack → OptimizeRanges preserves ranges
+ // Low latency: RequestTimeAsBytes=100 << slack -> OptimizeRanges preserves ranges
ChunkBlockAnalyser::Options Options;
Options.IsQuiet = true;
Options.HostLatencySec = 0.001;
Options.HostSpeedBytesPerSec = 100000;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
@@ -1792,17 +1770,16 @@ TEST_CASE("chunkblock.calc.multirangehighspeed_mode")
using namespace chunkblock_analyser_testutils;
using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
- // Block slack ≈ 714 bytes (TotalBlockSize≈1114, RangeTotalSize=400 for chunks 0+2)
- // RequestTimeAsBytes = 400000 * 1 * 0.001 = 400 < 714 → ranges preserved
+ // Block slack ~= 714 bytes (TotalBlockSize~=1114, RangeTotalSize=400 for chunks 0+2)
+ // RequestTimeAsBytes = 400000 * 1 * 0.001 = 400 < 714 -> ranges preserved
ChunkBlockAnalyser::Options Options;
Options.IsQuiet = true;
Options.HostHighSpeedLatencySec = 0.001;
Options.HostHighSpeedBytesPerSec = 400000;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
@@ -1824,17 +1801,16 @@ TEST_CASE("chunkblock.calc.all_chunks_needed_full_block")
using namespace chunkblock_analyser_testutils;
using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
ChunkBlockAnalyser::Options Options;
Options.IsQuiet = true;
Options.HostLatencySec = 0.001;
Options.HostSpeedBytesPerSec = 100000;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
- // All 4 chunks needed → short-circuit to full block regardless of mode
+ // All 4 chunks needed -> short-circuit to full block regardless of mode
std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 1, 2, 3}}};
std::vector<Mode> Modes = {Mode::Exact};
@@ -1850,14 +1826,13 @@ TEST_CASE("chunkblock.calc.headersize_zero_forces_full_block")
using namespace chunkblock_analyser_testutils;
using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
- // HeaderSize=0 → CanDoPartialBlockDownload=false → full block even in Exact mode
+ // HeaderSize=0 -> CanDoPartialBlockDownload=false -> full block even in Exact mode
auto Block = MakeBlockDesc(0, {100, 200, 300, 400});
ChunkBlockAnalyser::Options Options;
Options.IsQuiet = true;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2}}};
std::vector<Mode> Modes = {Mode::Exact};
@@ -1874,25 +1849,24 @@ TEST_CASE("chunkblock.calc.low_maxrangecountperrequest")
using namespace chunkblock_analyser_testutils;
using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
- // 5 chunks of 100 bytes each; need chunks 0, 2, 4 → 3 non-contiguous ranges
- // With MaxRangeCountPerRequest=1 and high latency, cost model merges aggressively → full block
+ // 5 chunks of 100 bytes each; need chunks 0, 2, 4 -> 3 non-contiguous ranges
+ // With MaxRangeCountPerRequest=1 and high latency, cost model merges aggressively -> full block
auto Block = MakeBlockDesc(10, {100, 100, 100, 100, 100});
ChunkBlockAnalyser::Options Options;
Options.IsQuiet = true;
Options.HostLatencySec = 0.1;
Options.HostSpeedBytesPerSec = 1000;
Options.HostMaxRangeCountPerRequest = 1;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
std::vector<ChunkBlockAnalyser::NeededBlock> NeededBlocks = {{.BlockIndex = 0, .ChunkIndexes = {0, 2, 4}}};
std::vector<Mode> Modes = {Mode::MultiRange};
auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
- // Cost model drives merging: 3 requests × 1000 × 0.1 = 300 > slack ≈ 210+headersize
+ // Cost model drives merging: 3 requests x 1000 x 0.1 = 300 > slack ~= 210+headersize
// After merges converges to full block
REQUIRE_EQ(1u, Result.FullBlockIndexes.size());
CHECK_EQ(0u, Result.FullBlockIndexes[0]);
@@ -1904,14 +1878,13 @@ TEST_CASE("chunkblock.calc.no_latency_skips_optimize")
using namespace chunkblock_analyser_testutils;
using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
auto Block = MakeBlockDesc(50, {100, 200, 300, 400});
- // Default HostLatencySec=-1 → OptimizeRanges not called; raw GetBlockRanges result used
+ // Default HostLatencySec=-1 -> OptimizeRanges not called; raw GetBlockRanges result used
ChunkBlockAnalyser::Options Options;
Options.IsQuiet = true;
- ChunkBlockAnalyser Analyser(*LogOutput, std::span<const ChunkBlockDescription>(&Block, 1), Options);
+ ChunkBlockAnalyser Analyser(LogRef, std::span<const ChunkBlockDescription>(&Block, 1), Options);
uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
@@ -1920,7 +1893,7 @@ TEST_CASE("chunkblock.calc.no_latency_skips_optimize")
auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
- // No optimize pass → exact ranges from GetBlockRanges
+ // No optimize pass -> exact ranges from GetBlockRanges
CHECK(Result.FullBlockIndexes.empty());
REQUIRE_EQ(2u, Result.BlockRanges.size());
CHECK_EQ(ChunkStartOffset, Result.BlockRanges[0].RangeStart);
@@ -1934,8 +1907,7 @@ TEST_CASE("chunkblock.calc.multiple_blocks_different_modes")
using namespace chunkblock_analyser_testutils;
using Mode = ChunkBlockAnalyser::EPartialBlockDownloadMode;
- LoggerRef LogRef = Log();
- std::unique_ptr<OperationLogOutput> LogOutput(CreateStandardLogOutput(LogRef));
+ LoggerRef LogRef = Log();
// 3 blocks with different modes: Off, Exact, MultiRange
auto Block0 = MakeBlockDesc(50, {100, 200, 300, 400}, /*BlockSeed=*/0);
@@ -1948,7 +1920,7 @@ TEST_CASE("chunkblock.calc.multiple_blocks_different_modes")
Options.HostSpeedBytesPerSec = 100000;
std::vector<ChunkBlockDescription> Blocks = {Block0, Block1, Block2};
- ChunkBlockAnalyser Analyser(*LogOutput, Blocks, Options);
+ ChunkBlockAnalyser Analyser(LogRef, Blocks, Options);
uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + 50;
@@ -1961,11 +1933,11 @@ TEST_CASE("chunkblock.calc.multiple_blocks_different_modes")
auto Result = Analyser.CalculatePartialBlockDownloads(NeededBlocks, Modes);
- // Block 0: Off → FullBlockIndexes
+ // Block 0: Off -> FullBlockIndexes
REQUIRE_EQ(1u, Result.FullBlockIndexes.size());
CHECK_EQ(0u, Result.FullBlockIndexes[0]);
- // Block 1: Exact → 2 ranges; Block 2: MultiRange (low latency) → 2 ranges
+ // Block 1: Exact -> 2 ranges; Block 2: MultiRange (low latency) -> 2 ranges
// Total: 4 ranges
REQUIRE_EQ(4u, Result.BlockRanges.size());
@@ -2058,7 +2030,7 @@ TEST_CASE("chunkblock.getblockranges.non_contiguous")
{
using namespace chunkblock_analyser_testutils;
- // Chunks 0 and 2 needed, chunk 1 skipped → two separate ranges
+ // Chunks 0 and 2 needed, chunk 1 skipped -> two separate ranges
auto Block = MakeBlockDesc(50, {100, 200, 300});
uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
@@ -2082,7 +2054,7 @@ TEST_CASE("chunkblock.getblockranges.contiguous_run")
{
using namespace chunkblock_analyser_testutils;
- // Chunks 1, 2, 3 needed (consecutive) → one merged range
+ // Chunks 1, 2, 3 needed (consecutive) -> one merged range
auto Block = MakeBlockDesc(50, {50, 100, 150, 200, 250});
uint64_t ChunkStartOffset = CompressedBuffer::GetHeaderSizeForNoneEncoder() + Block.HeaderSize;
diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h
index 0d2eded58..2d29e5efd 100644
--- a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h
+++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h
@@ -24,7 +24,7 @@ namespace zen {
class CloneQueryInterface;
-class OperationLogOutput;
+class ProgressBase;
class BuildStorageBase;
class HttpClient;
class ParallelWork;
@@ -128,7 +128,6 @@ public:
std::uint64_t PreferredMultipartChunkSize = 32u * 1024u * 1024u;
EPartialBlockRequestMode PartialBlockRequestMode = EPartialBlockRequestMode::Mixed;
bool WipeTargetFolder = false;
- bool PrimeCacheOnly = false;
bool EnableOtherDownloadsScavenging = true;
bool EnableTargetFolderScavenging = true;
bool ValidateCompletedSequences = true;
@@ -137,7 +136,8 @@ public:
bool PopulateCache = true;
};
- BuildsOperationUpdateFolder(OperationLogOutput& OperationLogOutput,
+ BuildsOperationUpdateFolder(LoggerRef Log,
+ ProgressBase& Progress,
StorageInstance& Storage,
std::atomic<bool>& AbortFlag,
std::atomic<bool>& PauseFlag,
@@ -209,6 +209,45 @@ private:
uint64_t ElapsedTimeMs = 0;
};
+ struct LooseChunkHashWorkData
+ {
+ std::vector<const ChunkedContentLookup::ChunkSequenceLocation*> ChunkTargetPtrs;
+ uint32_t RemoteChunkIndex = (uint32_t)-1;
+ };
+
+ struct FinalizeTarget
+ {
+ IoHash RawHash;
+ uint32_t RemotePathIndex;
+ };
+
+ struct LocalPathCategorization
+ {
+ std::vector<uint32_t> FilesToCache;
+ std::vector<uint32_t> RemoveLocalPathIndexes;
+ tsl::robin_map<uint32_t, uint32_t> RemotePathIndexToLocalPathIndex;
+ tsl::robin_map<IoHash, uint32_t, IoHash::Hasher> SequenceHashToLocalPathIndex;
+ uint64_t MatchCount = 0;
+ uint64_t PathMismatchCount = 0;
+ uint64_t HashMismatchCount = 0;
+ uint64_t SkippedCount = 0;
+ uint64_t DeleteCount = 0;
+ };
+
+ struct WriteChunksContext
+ {
+ ParallelWork& Work;
+ BufferedWriteFileCache& WriteCache;
+ std::span<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters;
+ std::span<std::atomic<bool>> RemoteChunkIndexNeedsCopyFromSourceFlags;
+ std::atomic<uint64_t>& WritePartsComplete;
+ uint64_t TotalPartWriteCount;
+ uint64_t TotalRequestCount;
+ const BlobsExistsResult& ExistsResult;
+ FilteredRate& FilteredDownloadedBytesPerSecond;
+ FilteredRate& FilteredWrittenBytesPerSecond;
+ };
+
void ScanCacheFolder(tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutCachedChunkHashesFound,
tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutCachedSequenceHashesFound);
void ScanTempBlocksFolder(tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutCachedBlocksFound);
@@ -261,12 +300,16 @@ private:
void DownloadBuildBlob(uint32_t RemoteChunkIndex,
const BlobsExistsResult& ExistsResult,
ParallelWork& Work,
+ uint64_t TotalRequestCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond,
std::function<void(IoBuffer&& Payload)>&& OnDownloaded);
- void DownloadPartialBlock(std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRanges,
- size_t BlockRangeIndex,
- size_t BlockRangeCount,
- const BlobsExistsResult& ExistsResult,
+ void DownloadPartialBlock(std::span<const ChunkBlockAnalyser::BlockRangeDescriptor> BlockRanges,
+ size_t BlockRangeIndex,
+ size_t BlockRangeCount,
+ const BlobsExistsResult& ExistsResult,
+ uint64_t TotalRequestCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond,
std::function<void(IoBuffer&& InMemoryBuffer,
const std::filesystem::path& OnDiskPath,
size_t BlockRangeStartIndex,
@@ -323,8 +366,8 @@ private:
std::span<std::atomic<bool>> RemoteChunkIndexNeedsCopyFromSourceFlags,
BufferedWriteFileCache& WriteCache);
- void AsyncWriteDownloadedChunk(const std::filesystem::path& ZenFolderPath,
- uint32_t RemoteChunkIndex,
+ void AsyncWriteDownloadedChunk(uint32_t RemoteChunkIndex,
+ const BlobsExistsResult& ExistsResult,
std::vector<const ChunkedContentLookup::ChunkSequenceLocation*>&& ChunkTargetPtrs,
BufferedWriteFileCache& WriteCache,
ParallelWork& Work,
@@ -332,8 +375,7 @@ private:
std::span<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters,
std::atomic<uint64_t>& WritePartsComplete,
const uint64_t TotalPartWriteCount,
- FilteredRate& FilteredWrittenBytesPerSecond,
- bool EnableBacklog);
+ FilteredRate& FilteredWrittenBytesPerSecond);
void VerifyAndCompleteChunkSequencesAsync(std::span<const uint32_t> RemoteSequenceIndexes, ParallelWork& Work);
bool CompleteSequenceChunk(uint32_t RemoteSequenceIndex, std::span<std::atomic<uint32_t>> SequenceIndexChunksLeftToWriteCounters);
@@ -343,7 +385,112 @@ private:
void FinalizeChunkSequences(std::span<const uint32_t> RemoteSequenceIndexes);
void VerifySequence(uint32_t RemoteSequenceIndex);
- OperationLogOutput& m_LogOutput;
+ void InitializeSequenceCounters(std::vector<std::atomic<uint32_t>>& OutSequenceCounters,
+ tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& OutSequencesLeftToFind,
+ const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& CachedChunkHashesFound,
+ const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& CachedSequenceHashesFound);
+
+ void MatchScavengedSequencesToRemote(std::span<const ChunkedFolderContent> Contents,
+ std::span<const ChunkedContentLookup> Lookups,
+ std::span<const std::filesystem::path> Paths,
+ tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& InOutSequencesLeftToFind,
+ std::vector<std::atomic<uint32_t>>& InOutSequenceCounters,
+ std::vector<ScavengedSequenceCopyOperation>& OutCopyOperations,
+ uint64_t& OutScavengedPathsCount);
+
+ uint64_t CalculateBytesToWriteAndFlagNeededChunks(std::span<const std::atomic<uint32_t>> SequenceCounters,
+ const std::vector<bool>& NeedsCopyFromLocalFileFlags,
+ std::span<std::atomic<bool>> OutNeedsCopyFromSourceFlags);
+
+ void ClassifyCachedAndFetchBlocks(std::span<const ChunkBlockAnalyser::NeededBlock> NeededBlocks,
+ const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& CachedBlocksFound,
+ uint64_t& TotalPartWriteCount,
+ std::vector<uint32_t>& OutCachedChunkBlockIndexes,
+ std::vector<uint32_t>& OutFetchBlockIndexes);
+
+ std::vector<uint32_t> DetermineNeededLooseChunkIndexes(std::span<const std::atomic<uint32_t>> SequenceCounters,
+ const std::vector<bool>& NeedsCopyFromLocalFileFlags,
+ std::span<std::atomic<bool>> NeedsCopyFromSourceFlags);
+
+ BlobsExistsResult QueryBlobCacheExists(std::span<const uint32_t> NeededLooseChunkIndexes, std::span<const uint32_t> FetchBlockIndexes);
+
+ std::vector<ChunkBlockAnalyser::EPartialBlockDownloadMode> DeterminePartialDownloadModes(const BlobsExistsResult& ExistsResult);
+
+ std::vector<LooseChunkHashWorkData> BuildLooseChunkHashWorks(std::span<const uint32_t> NeededLooseChunkIndexes,
+ std::span<const std::atomic<uint32_t>> SequenceCounters);
+
+ void VerifyWriteChunksComplete(std::span<const std::atomic<uint32_t>> SequenceCounters,
+ uint64_t BytesToWrite,
+ uint64_t BytesToValidate);
+
+ std::vector<FinalizeTarget> BuildSortedFinalizeTargets();
+
+ void ScanScavengeSources(std::span<const ScavengeSource> Sources,
+ std::vector<ChunkedFolderContent>& OutContents,
+ std::vector<ChunkedContentLookup>& OutLookups,
+ std::vector<std::filesystem::path>& OutPaths);
+
+ LocalPathCategorization CategorizeLocalPaths(const tsl::robin_map<std::string, uint32_t>& RemotePathToRemoteIndex);
+
+ void ScheduleLocalFileCaching(std::span<const uint32_t> FilesToCache,
+ std::atomic<uint64_t>& OutCachedCount,
+ std::atomic<uint64_t>& OutCachedByteCount);
+
+ void ScheduleScavengedSequenceWrites(WriteChunksContext& Context,
+ std::span<const ScavengedSequenceCopyOperation> CopyOperations,
+ const std::vector<ChunkedFolderContent>& ScavengedContents,
+ const std::vector<std::filesystem::path>& ScavengedPaths);
+
+ void ScheduleLooseChunkWrites(WriteChunksContext& Context, std::vector<LooseChunkHashWorkData>& LooseChunkHashWorks);
+
+ void ScheduleLocalChunkCopies(WriteChunksContext& Context,
+ std::span<const CopyChunkData> CopyChunkDatas,
+ CloneQueryInterface* CloneQuery,
+ const std::vector<ChunkedFolderContent>& ScavengedContents,
+ const std::vector<ChunkedContentLookup>& ScavengedLookups,
+ const std::vector<std::filesystem::path>& ScavengedPaths);
+
+ void ScheduleCachedBlockWrites(WriteChunksContext& Context, std::span<const uint32_t> CachedBlockIndexes);
+
+ void SchedulePartialBlockDownloads(WriteChunksContext& Context, const ChunkBlockAnalyser::BlockResult& PartialBlocks);
+
+ void WritePartialBlockToCache(WriteChunksContext& Context,
+ size_t BlockRangeStartIndex,
+ IoBuffer BlockPartialBuffer,
+ const std::filesystem::path& BlockChunkPath,
+ std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengths,
+ const ChunkBlockAnalyser::BlockResult& PartialBlocks);
+
+ void ScheduleFullBlockDownloads(WriteChunksContext& Context, std::span<const uint32_t> FullBlockIndexes);
+
+ void WriteFullBlockToCache(WriteChunksContext& Context,
+ uint32_t BlockIndex,
+ IoBuffer BlockBuffer,
+ const std::filesystem::path& BlockChunkPath);
+
+ void ScheduleLocalFileRemovals(ParallelWork& Work,
+ std::span<const uint32_t> RemoveLocalPathIndexes,
+ std::atomic<uint64_t>& DeletedCount);
+
+ void ScheduleTargetFinalization(ParallelWork& Work,
+ std::span<const FinalizeTarget> Targets,
+ const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& SequenceHashToLocalPathIndex,
+ const tsl::robin_map<uint32_t, uint32_t>& RemotePathIndexToLocalPathIndex,
+ FolderContent& OutLocalFolderState,
+ std::atomic<uint64_t>& TargetsComplete);
+
+ void FinalizeTargetGroup(size_t BaseOffset,
+ size_t Count,
+ std::span<const FinalizeTarget> Targets,
+ const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& SequenceHashToLocalPathIndex,
+ const tsl::robin_map<uint32_t, uint32_t>& RemotePathIndexToLocalPathIndex,
+ FolderContent& OutLocalFolderState,
+ std::atomic<uint64_t>& TargetsComplete);
+
+ LoggerRef Log() { return m_Log; }
+
+ LoggerRef m_Log;
+ ProgressBase& m_Progress;
StorageInstance& m_Storage;
std::atomic<bool>& m_AbortFlag;
std::atomic<bool>& m_PauseFlag;
@@ -492,7 +639,8 @@ public:
bool PopulateCache = true;
};
- BuildsOperationUploadFolder(OperationLogOutput& OperationLogOutput,
+ BuildsOperationUploadFolder(LoggerRef Log,
+ ProgressBase& Progress,
StorageInstance& Storage,
std::atomic<bool>& AbortFlag,
std::atomic<bool>& PauseFlag,
@@ -569,6 +717,28 @@ private:
GenerateBlocksStatistics& GenerateBlocksStats,
UploadStatistics& UploadStats);
+ struct GenerateBuildBlocksContext
+ {
+ ParallelWork& Work;
+ WorkerThreadPool& GenerateBlobsPool;
+ WorkerThreadPool& UploadBlocksPool;
+ FilteredRate& FilteredGeneratedBytesPerSecond;
+ FilteredRate& FilteredUploadedBytesPerSecond;
+ std::atomic<uint64_t>& QueuedPendingBlocksForUpload;
+ RwLock& Lock;
+ GeneratedBlocks& OutBlocks;
+ GenerateBlocksStatistics& GenerateBlocksStats;
+ UploadStatistics& UploadStats;
+ size_t NewBlockCount;
+ };
+
+ void ScheduleBlockGeneration(GenerateBuildBlocksContext& Context,
+ const ChunkedFolderContent& Content,
+ const ChunkedContentLookup& Lookup,
+ const std::vector<std::vector<uint32_t>>& NewBlockChunks);
+
+ void UploadGeneratedBlock(GenerateBuildBlocksContext& Context, size_t BlockIndex, CompressedBuffer Payload);
+
std::vector<uint32_t> CalculateAbsoluteChunkOrders(const std::span<const IoHash> LocalChunkHashes,
const std::span<const uint32_t> LocalChunkOrder,
const tsl::robin_map<IoHash, uint32_t, IoHash::Hasher>& ChunkHashToLocalChunkIndex,
@@ -609,6 +779,58 @@ private:
uint32_t PartStepOffset,
uint32_t StepCount);
+ ChunkedFolderContent ScanPartContent(const UploadPart& Part,
+ ChunkingController& ChunkController,
+ ChunkingCache& ChunkCache,
+ ChunkingStatistics& ChunkingStats);
+
+ void ConsumePrepareBuildResult();
+
+ void ClassifyChunksByBlockEligibility(const ChunkedFolderContent& LocalContent,
+ std::vector<uint32_t>& OutLooseChunkIndexes,
+ std::vector<uint32_t>& OutNewBlockChunkIndexes,
+ std::vector<size_t>& OutReuseBlockIndexes,
+ LooseChunksStatistics& LooseChunksStats,
+ FindBlocksStatistics& FindBlocksStats,
+ ReuseBlocksStatistics& ReuseBlocksStats);
+
+ struct BuiltPartManifest
+ {
+ CbObject PartManifest;
+ std::vector<ChunkBlockDescription> AllChunkBlockDescriptions;
+ std::vector<IoHash> AllChunkBlockHashes;
+ };
+
+ BuiltPartManifest BuildPartManifestObject(const ChunkedFolderContent& LocalContent,
+ const ChunkedContentLookup& LocalLookup,
+ ChunkingController& ChunkController,
+ std::span<const size_t> ReuseBlockIndexes,
+ const GeneratedBlocks& NewBlocks,
+ std::span<const uint32_t> LooseChunkIndexes);
+
+ void UploadAttachmentBatch(std::span<IoHash> RawHashes,
+ std::vector<IoHash>& OutUnknownChunks,
+ const ChunkedFolderContent& LocalContent,
+ const ChunkedContentLookup& LocalLookup,
+ const std::vector<std::vector<uint32_t>>& NewBlockChunks,
+ GeneratedBlocks& NewBlocks,
+ std::span<const uint32_t> LooseChunkIndexes,
+ UploadStatistics& UploadStats,
+ LooseChunksStatistics& LooseChunksStats);
+
+ void FinalizeBuildPartWithRetries(const UploadPart& Part,
+ const IoHash& PartHash,
+ std::vector<IoHash>& InOutUnknownChunks,
+ const ChunkedFolderContent& LocalContent,
+ const ChunkedContentLookup& LocalLookup,
+ const std::vector<std::vector<uint32_t>>& NewBlockChunks,
+ GeneratedBlocks& NewBlocks,
+ std::span<const uint32_t> LooseChunkIndexes,
+ UploadStatistics& UploadStats,
+ LooseChunksStatistics& LooseChunksStats);
+
+ void UploadMissingBlockMetadata(GeneratedBlocks& NewBlocks, UploadStatistics& UploadStats);
+
void UploadPartBlobs(const ChunkedFolderContent& Content,
const ChunkedContentLookup& Lookup,
std::span<IoHash> RawHashes,
@@ -620,18 +842,72 @@ private:
LooseChunksStatistics& TempLooseChunksStats,
std::vector<IoHash>& OutUnknownChunks);
+ struct UploadPartClassification
+ {
+ std::vector<size_t> BlockIndexes;
+ std::vector<uint32_t> LooseChunkOrderIndexes;
+ uint64_t TotalBlocksSize = 0;
+ uint64_t TotalLooseChunksSize = 0;
+ };
+
+ UploadPartClassification ClassifyUploadRawHashes(std::span<IoHash> RawHashes,
+ const ChunkedFolderContent& Content,
+ const ChunkedContentLookup& Lookup,
+ const GeneratedBlocks& NewBlocks,
+ std::span<const uint32_t> LooseChunkIndexes,
+ std::vector<IoHash>& OutUnknownChunks);
+
+ struct UploadPartBlobsContext
+ {
+ ParallelWork& Work;
+ WorkerThreadPool& ReadChunkPool;
+ WorkerThreadPool& UploadChunkPool;
+ FilteredRate& FilteredGenerateBlockBytesPerSecond;
+ FilteredRate& FilteredCompressedBytesPerSecond;
+ FilteredRate& FilteredUploadedBytesPerSecond;
+ std::atomic<size_t>& UploadedBlockSize;
+ std::atomic<size_t>& UploadedBlockCount;
+ std::atomic<size_t>& UploadedRawChunkSize;
+ std::atomic<size_t>& UploadedCompressedChunkSize;
+ std::atomic<uint32_t>& UploadedChunkCount;
+ std::atomic<uint64_t>& GeneratedBlockCount;
+ std::atomic<uint64_t>& GeneratedBlockByteCount;
+ std::atomic<uint64_t>& QueuedPendingInMemoryBlocksForUpload;
+ size_t UploadBlockCount;
+ uint32_t UploadChunkCount;
+ uint64_t LargeAttachmentSize;
+ GeneratedBlocks& NewBlocks;
+ const ChunkedFolderContent& Content;
+ const ChunkedContentLookup& Lookup;
+ const std::vector<std::vector<uint32_t>>& NewBlockChunks;
+ std::span<const uint32_t> LooseChunkIndexes;
+ UploadStatistics& TempUploadStats;
+ LooseChunksStatistics& TempLooseChunksStats;
+ };
+
+ void ScheduleBlockGenerationAndUpload(UploadPartBlobsContext& Context, std::span<const size_t> BlockIndexes);
+
+ void ScheduleLooseChunkCompressionAndUpload(UploadPartBlobsContext& Context, std::span<const uint32_t> LooseChunkOrderIndexes);
+
+ void UploadBlockPayload(UploadPartBlobsContext& Context, size_t BlockIndex, const IoHash& BlockHash, CompositeBuffer Payload);
+
+ void UploadLooseChunkPayload(UploadPartBlobsContext& Context, const IoHash& RawHash, uint64_t RawSize, CompositeBuffer Payload);
+
CompositeBuffer CompressChunk(const ChunkedFolderContent& Content,
const ChunkedContentLookup& Lookup,
uint32_t ChunkIndex,
LooseChunksStatistics& TempLooseChunksStats);
- OperationLogOutput& m_LogOutput;
- StorageInstance& m_Storage;
- std::atomic<bool>& m_AbortFlag;
- std::atomic<bool>& m_PauseFlag;
- WorkerThreadPool& m_IOWorkerPool;
- WorkerThreadPool& m_NetworkPool;
- const Oid m_BuildId;
+ LoggerRef Log() { return m_Log; }
+
+ LoggerRef m_Log;
+ ProgressBase& m_Progress;
+ StorageInstance& m_Storage;
+ std::atomic<bool>& m_AbortFlag;
+ std::atomic<bool>& m_PauseFlag;
+ WorkerThreadPool& m_IOWorkerPool;
+ WorkerThreadPool& m_NetworkPool;
+ const Oid m_BuildId;
const std::filesystem::path m_Path;
const bool m_CreateBuild; // ?? Member?
@@ -665,7 +941,8 @@ public:
bool IsQuiet = false;
bool IsVerbose = false;
};
- BuildsOperationValidateBuildPart(OperationLogOutput& OperationLogOutput,
+ BuildsOperationValidateBuildPart(LoggerRef Log,
+ ProgressBase& Progress,
BuildStorageBase& Storage,
std::atomic<bool>& AbortFlag,
std::atomic<bool>& PauseFlag,
@@ -682,21 +959,61 @@ public:
DownloadStatistics m_DownloadStats;
private:
+ enum class TaskSteps : uint32_t
+ {
+ FetchBuild,
+ FetchBuildPart,
+ ValidateBlobs,
+ Cleanup,
+ StepCount
+ };
+
ChunkBlockDescription ValidateChunkBlock(IoBuffer&& Payload,
const IoHash& BlobHash,
uint64_t& OutCompressedSize,
uint64_t& OutDecompressedSize);
- OperationLogOutput& m_LogOutput;
- BuildStorageBase& m_Storage;
- std::atomic<bool>& m_AbortFlag;
- std::atomic<bool>& m_PauseFlag;
- WorkerThreadPool& m_IOWorkerPool;
- WorkerThreadPool& m_NetworkPool;
- const Oid m_BuildId;
- Oid m_BuildPartId;
- const std::string m_BuildPartName;
- const Options m_Options;
+ struct ValidateBlobsContext
+ {
+ ParallelWork& Work;
+ uint64_t AttachmentsToVerifyCount;
+ FilteredRate& FilteredDownloadedBytesPerSecond;
+ FilteredRate& FilteredVerifiedBytesPerSecond;
+ };
+
+ struct ResolvedBuildPart
+ {
+ std::vector<IoHash> ChunkAttachments;
+ std::vector<IoHash> BlockAttachments;
+ uint64_t PreferredMultipartChunkSize = 0;
+ };
+
+ ResolvedBuildPart ResolveBuildPart();
+
+ void ScheduleChunkAttachmentValidation(ValidateBlobsContext& Context,
+ std::span<const IoHash> ChunkAttachments,
+ const std::filesystem::path& TempFolder,
+ uint64_t PreferredMultipartChunkSize);
+
+ void ScheduleBlockAttachmentValidation(ValidateBlobsContext& Context, std::span<const IoHash> BlockAttachments);
+
+ void ValidateDownloadedChunk(ValidateBlobsContext& Context, const IoHash& ChunkHash, IoBuffer Payload);
+
+ void ValidateDownloadedBlock(ValidateBlobsContext& Context, const IoHash& BlockAttachment, IoBuffer Payload);
+
+ LoggerRef Log() { return m_Log; }
+
+ LoggerRef m_Log;
+ ProgressBase& m_Progress;
+ BuildStorageBase& m_Storage;
+ std::atomic<bool>& m_AbortFlag;
+ std::atomic<bool>& m_PauseFlag;
+ WorkerThreadPool& m_IOWorkerPool;
+ WorkerThreadPool& m_NetworkPool;
+ const Oid m_BuildId;
+ Oid m_BuildPartId;
+ const std::string m_BuildPartName;
+ const Options m_Options;
};
class BuildsOperationPrimeCache
@@ -712,7 +1029,8 @@ public:
bool ForceUpload = false;
};
- BuildsOperationPrimeCache(OperationLogOutput& OperationLogOutput,
+ BuildsOperationPrimeCache(LoggerRef Log,
+ ProgressBase& Progress,
StorageInstance& Storage,
std::atomic<bool>& AbortFlag,
std::atomic<bool>& PauseFlag,
@@ -727,7 +1045,33 @@ public:
DownloadStatistics m_DownloadStats;
private:
- OperationLogOutput& m_LogOutput;
+ LoggerRef Log() { return m_Log; }
+
+ void CollectReferencedBlobs(tsl::robin_set<IoHash, IoHash::Hasher>& OutBuildBlobs,
+ tsl::robin_map<IoHash, uint64_t, IoHash::Hasher>& OutLooseChunkRawSizes);
+
+ std::vector<IoHash> FilterAlreadyCachedBlobs(const tsl::robin_set<IoHash, IoHash::Hasher>& BuildBlobs);
+
+ void ScheduleBlobDownloads(std::span<const IoHash> BlobsToDownload,
+ const tsl::robin_map<IoHash, uint64_t, IoHash::Hasher>& LooseChunkRawSizes,
+ std::atomic<uint64_t>& MultipartAttachmentCount,
+ std::atomic<size_t>& CompletedDownloadCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond);
+
+ void DownloadLargeBlobForCache(ParallelWork& Work,
+ const IoHash& BlobHash,
+ size_t BlobCount,
+ std::atomic<size_t>& CompletedDownloadCount,
+ std::atomic<uint64_t>& MultipartAttachmentCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond);
+
+ void DownloadSingleBlobForCache(const IoHash& BlobHash,
+ size_t BlobCount,
+ std::atomic<size_t>& CompletedDownloadCount,
+ FilteredRate& FilteredDownloadedBytesPerSecond);
+
+ LoggerRef m_Log;
+ ProgressBase& m_Progress;
StorageInstance& m_Storage;
std::atomic<bool>& m_AbortFlag;
std::atomic<bool>& m_PauseFlag;
@@ -755,7 +1099,7 @@ std::vector<std::pair<Oid, std::string>> ResolveBuildPartNames(CbObjectView
struct BuildManifest;
-ChunkedFolderContent GetRemoteContent(OperationLogOutput& Output,
+ChunkedFolderContent GetRemoteContent(LoggerRef InLog,
StorageInstance& Storage,
const Oid& BuildId,
const std::vector<std::pair<Oid, std::string>>& BuildParts,
diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h
index 7306188ca..c55c930bc 100644
--- a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h
+++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h
@@ -8,7 +8,6 @@
namespace zen {
-class OperationLogOutput;
class BuildStorageBase;
class BuildStorageCache;
@@ -38,7 +37,7 @@ enum class ZenCacheResolveMode
All
};
-BuildStorageResolveResult ResolveBuildStorage(OperationLogOutput& Output,
+BuildStorageResolveResult ResolveBuildStorage(LoggerRef InLog,
const HttpClientSettings& ClientSettings,
std::string_view Host,
std::string_view OverrideHost,
@@ -46,7 +45,7 @@ BuildStorageResolveResult ResolveBuildStorage(OperationLogOutput& Output,
ZenCacheResolveMode ZenResolveMode,
bool Verbose);
-std::vector<ChunkBlockDescription> GetBlockDescriptions(OperationLogOutput& Output,
+std::vector<ChunkBlockDescription> GetBlockDescriptions(LoggerRef InLog,
BuildStorageBase& Storage,
BuildStorageCache* OptionalCacheStorage,
const Oid& BuildId,
diff --git a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h
index e3a5f6539..73d037542 100644
--- a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h
+++ b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h
@@ -3,6 +3,7 @@
#pragma once
#include <zencore/iohash.h>
+#include <zencore/logbase.h>
#include <zencore/compactbinary.h>
#include <zencore/compress.h>
@@ -64,9 +65,7 @@ struct ReuseBlocksStatistics
}
};
-class OperationLogOutput;
-
-std::vector<size_t> FindReuseBlocks(OperationLogOutput& Output,
+std::vector<size_t> FindReuseBlocks(LoggerRef InLog,
const uint8_t BlockReuseMinPercentLimit,
const bool IsVerbose,
ReuseBlocksStatistics& Stats,
@@ -91,7 +90,7 @@ public:
uint64_t MaxRangesPerBlock = 1024u;
};
- ChunkBlockAnalyser(OperationLogOutput& LogOutput, std::span<const ChunkBlockDescription> BlockDescriptions, const Options& Options);
+ ChunkBlockAnalyser(LoggerRef Log, std::span<const ChunkBlockDescription> BlockDescriptions, const Options& Options);
struct BlockRangeDescriptor
{
@@ -130,7 +129,9 @@ public:
std::span<const EPartialBlockDownloadMode> BlockPartialDownloadModes);
private:
- OperationLogOutput& m_LogOutput;
+ LoggerRef Log() { return m_Log; }
+
+ LoggerRef m_Log;
const std::span<const ChunkBlockDescription> m_BlockDescriptions;
const Options m_Options;
};
diff --git a/src/zenremotestore/include/zenremotestore/operationlogoutput.h b/src/zenremotestore/include/zenremotestore/operationlogoutput.h
deleted file mode 100644
index 32b95f50f..000000000
--- a/src/zenremotestore/include/zenremotestore/operationlogoutput.h
+++ /dev/null
@@ -1,76 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zencore/fmtutils.h>
-#include <zencore/logbase.h>
-
-namespace zen {
-
-class OperationLogOutput
-{
-public:
- virtual ~OperationLogOutput() {}
- virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) = 0;
-
- virtual void SetLogOperationName(std::string_view Name) = 0;
- virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) = 0;
- virtual uint32_t GetProgressUpdateDelayMS() = 0;
-
- class ProgressBar
- {
- public:
- struct State
- {
- bool operator==(const State&) const = default;
- std::string Task;
- std::string Details;
- uint64_t TotalCount = 0;
- uint64_t RemainingCount = 0;
- enum class EStatus
- {
- Running,
- Aborted,
- Paused
- };
- EStatus Status = EStatus::Running;
-
- static EStatus CalculateStatus(bool IsAborted, bool IsPaused)
- {
- if (IsAborted)
- {
- return EStatus::Aborted;
- }
- if (IsPaused)
- {
- return EStatus::Paused;
- }
- return EStatus::Running;
- }
- };
-
- virtual ~ProgressBar() {}
-
- virtual void UpdateState(const State& NewState, bool DoLinebreak) = 0;
- virtual void Finish() = 0;
- };
-
- virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) = 0;
-};
-
-OperationLogOutput* CreateStandardLogOutput(LoggerRef Log);
-
-#define ZEN_OPERATION_LOG(OutputTarget, InLevel, fmtstr, ...) \
- do \
- { \
- using namespace std::literals; \
- static constinit zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \
- ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \
- (OutputTarget).EmitLogMessage(LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \
- } while (false)
-
-#define ZEN_OPERATION_LOG_INFO(OutputTarget, fmtstr, ...) ZEN_OPERATION_LOG(OutputTarget, zen::logging::Info, fmtstr, ##__VA_ARGS__)
-#define ZEN_OPERATION_LOG_DEBUG(OutputTarget, fmtstr, ...) ZEN_OPERATION_LOG(OutputTarget, zen::logging::Debug, fmtstr, ##__VA_ARGS__)
-#define ZEN_OPERATION_LOG_WARN(OutputTarget, fmtstr, ...) ZEN_OPERATION_LOG(OutputTarget, zen::logging::Warn, fmtstr, ##__VA_ARGS__)
-
-} // namespace zen
diff --git a/src/zenremotestore/include/zenremotestore/projectstore/projectstoreoperations.h b/src/zenremotestore/include/zenremotestore/projectstore/projectstoreoperations.h
index a07ede6f6..db5b27d3f 100644
--- a/src/zenremotestore/include/zenremotestore/projectstore/projectstoreoperations.h
+++ b/src/zenremotestore/include/zenremotestore/projectstore/projectstoreoperations.h
@@ -20,7 +20,7 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
class BuildStorageBase;
-class OperationLogOutput;
+class ProgressBase;
struct StorageInstance;
class ProjectStoreOperationOplogState
@@ -34,10 +34,7 @@ public:
std::filesystem::path TempFolderPath;
};
- ProjectStoreOperationOplogState(OperationLogOutput& OperationLogOutput,
- StorageInstance& Storage,
- const Oid& BuildId,
- const Options& Options);
+ ProjectStoreOperationOplogState(LoggerRef Log, StorageInstance& Storage, const Oid& BuildId, const Options& Options);
CbObjectView LoadBuildObject();
CbObjectView LoadBuildPartsObject();
@@ -51,10 +48,12 @@ public:
const Oid& GetBuildPartId();
private:
- OperationLogOutput& m_LogOutput;
- StorageInstance& m_Storage;
- const Oid m_BuildId;
- const Options m_Options;
+ LoggerRef Log() { return m_Log; }
+
+ LoggerRef m_Log;
+ StorageInstance& m_Storage;
+ const Oid m_BuildId;
+ const Options m_Options;
Oid m_BuildPartId = Oid::Zero;
CbObject m_BuildObject;
@@ -79,7 +78,8 @@ public:
bool PopulateCache = true;
};
- ProjectStoreOperationDownloadAttachments(OperationLogOutput& OperationLogOutput,
+ ProjectStoreOperationDownloadAttachments(LoggerRef Log,
+ ProgressBase& Progress,
StorageInstance& Storage,
std::atomic<bool>& AbortFlag,
std::atomic<bool>& PauseFlag,
@@ -92,12 +92,15 @@ public:
void Execute();
private:
- OperationLogOutput& m_LogOutput;
- StorageInstance& m_Storage;
- std::atomic<bool>& m_AbortFlag;
- std::atomic<bool>& m_PauseFlag;
- WorkerThreadPool& m_IOWorkerPool;
- WorkerThreadPool& m_NetworkPool;
+ LoggerRef Log() { return m_Log; }
+
+ LoggerRef m_Log;
+ ProgressBase& m_Progress;
+ StorageInstance& m_Storage;
+ std::atomic<bool>& m_AbortFlag;
+ std::atomic<bool>& m_PauseFlag;
+ WorkerThreadPool& m_IOWorkerPool;
+ WorkerThreadPool& m_NetworkPool;
ProjectStoreOperationOplogState& m_State;
const tsl::robin_set<IoHash, IoHash::Hasher> m_AttachmentHashes;
diff --git a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h
index 8df892053..b81708341 100644
--- a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h
+++ b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h
@@ -152,7 +152,8 @@ struct RemoteStoreOptions
typedef std::function<CompositeBuffer(const IoHash& AttachmentHash)> TGetAttachmentBufferFunc;
-CbObject BuildContainer(CidStore& ChunkStore,
+CbObject BuildContainer(LoggerRef InLog,
+ CidStore& ChunkStore,
ProjectStore::Project& Project,
ProjectStore::Oplog& Oplog,
WorkerThreadPool& WorkerPool,
@@ -205,7 +206,8 @@ RemoteProjectStore::Result SaveOplogContainer(
const std::function<void(const ChunkedInfo& Chunked)>& OnChunkedAttachment,
JobContext* OptionalContext);
-void SaveOplog(CidStore& ChunkStore,
+void SaveOplog(LoggerRef InLog,
+ CidStore& ChunkStore,
RemoteProjectStore& RemoteStore,
ProjectStore::Project& Project,
ProjectStore::Oplog& Oplog,
@@ -222,6 +224,7 @@ void SaveOplog(CidStore& ChunkStore,
struct LoadOplogContext
{
+ LoggerRef Log;
CidStore& ChunkStore;
RemoteProjectStore& RemoteStore;
BuildStorageCache* OptionalCache = nullptr;
diff --git a/src/zenremotestore/jupiter/jupitersession.cpp b/src/zenremotestore/jupiter/jupitersession.cpp
index a9788cb4e..d610d1fc8 100644
--- a/src/zenremotestore/jupiter/jupitersession.cpp
+++ b/src/zenremotestore/jupiter/jupitersession.cpp
@@ -673,7 +673,7 @@ JupiterSession::PutMultipartBuildBlob(std::string_view Namespace,
size_t RetryPartIndex = PartNameToIndex.at(RetryPartId);
const MultipartUploadResponse::Part& RetryPart = Workload->PartDescription.Parts[RetryPartIndex];
IoBuffer RetryPartPayload =
- Workload->Transmitter(RetryPart.FirstByte, RetryPart.LastByte - RetryPart.FirstByte - 1);
+ Workload->Transmitter(RetryPart.FirstByte, RetryPart.LastByte - RetryPart.FirstByte);
std::string RetryMultipartUploadResponseRequestString =
fmt::format("/api/v2/builds/{}/{}/{}/blobs/{}/uploadMultipart{}&supportsRedirect={}",
Namespace,
diff --git a/src/zenremotestore/operationlogoutput.cpp b/src/zenremotestore/operationlogoutput.cpp
deleted file mode 100644
index 5ed844c9d..000000000
--- a/src/zenremotestore/operationlogoutput.cpp
+++ /dev/null
@@ -1,103 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include <zenremotestore/operationlogoutput.h>
-
-#include <zencore/logging.h>
-#include <zencore/logging/logger.h>
-
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <gsl/gsl-lite.hpp>
-ZEN_THIRD_PARTY_INCLUDES_END
-
-namespace zen {
-
-class StandardLogOutput;
-
-class StandardLogOutputProgressBar : public OperationLogOutput::ProgressBar
-{
-public:
- StandardLogOutputProgressBar(StandardLogOutput& Output, std::string_view InSubTask) : m_Output(Output), m_SubTask(InSubTask) {}
-
- virtual void UpdateState(const State& NewState, bool DoLinebreak) override;
- virtual void Finish() override;
-
-private:
- StandardLogOutput& m_Output;
- std::string m_SubTask;
- State m_State;
-};
-
-class StandardLogOutput : public OperationLogOutput
-{
-public:
- StandardLogOutput(LoggerRef& Log) : m_Log(Log) {}
- virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override
- {
- if (m_Log.ShouldLog(Point.Level))
- {
- m_Log->Log(Point, Args);
- }
- }
-
- virtual void SetLogOperationName(std::string_view Name) override
- {
- m_LogOperationName = Name;
- ZEN_OPERATION_LOG_INFO(*this, "{}", m_LogOperationName);
- }
- virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override
- {
- [[maybe_unused]] const size_t PercentDone = StepCount > 0u ? gsl::narrow<uint8_t>((100 * StepIndex) / StepCount) : 0u;
- ZEN_OPERATION_LOG_INFO(*this, "{}: {}%", m_LogOperationName, PercentDone);
- }
- virtual uint32_t GetProgressUpdateDelayMS() override { return 2000; }
- virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override
- {
- return new StandardLogOutputProgressBar(*this, InSubTask);
- }
-
-private:
- LoggerRef m_Log;
- std::string m_LogOperationName;
- LoggerRef Log() { return m_Log; }
-};
-
-void
-StandardLogOutputProgressBar::UpdateState(const State& NewState, bool DoLinebreak)
-{
- ZEN_UNUSED(DoLinebreak);
- [[maybe_unused]] const size_t PercentDone =
- NewState.TotalCount > 0u ? gsl::narrow<uint8_t>((100 * (NewState.TotalCount - NewState.RemainingCount)) / NewState.TotalCount) : 0u;
- std::string Task = NewState.Task;
- switch (NewState.Status)
- {
- case State::EStatus::Aborted:
- Task = "Aborting";
- break;
- case State::EStatus::Paused:
- Task = "Paused";
- break;
- default:
- break;
- }
- ZEN_OPERATION_LOG_INFO(m_Output, "{}: {}%{}", Task, PercentDone, NewState.Details.empty() ? "" : fmt::format(" {}", NewState.Details));
- m_State = NewState;
-}
-void
-StandardLogOutputProgressBar::Finish()
-{
- if (m_State.RemainingCount > 0)
- {
- State NewState = m_State;
- NewState.RemainingCount = 0;
- NewState.Details = "";
- UpdateState(NewState, /*DoLinebreak*/ true);
- }
-}
-
-OperationLogOutput*
-CreateStandardLogOutput(LoggerRef Log)
-{
- return new StandardLogOutput(Log);
-}
-
-} // namespace zen
diff --git a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp
index e95d9118c..d7596263b 100644
--- a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp
+++ b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp
@@ -9,7 +9,6 @@
#include <zenremotestore/builds/buildstorageutil.h>
#include <zenremotestore/builds/jupiterbuildstorage.h>
-#include <zenremotestore/operationlogoutput.h>
#include <numeric>
@@ -436,8 +435,6 @@ public:
BuildStorageCache* OptionalCache,
const Oid& CacheBuildId) override
{
- std::unique_ptr<OperationLogOutput> Output(CreateStandardLogOutput(Log()));
-
ZEN_ASSERT(m_OplogBuildPartId != Oid::Zero);
ZEN_ASSERT(OptionalCache == nullptr || CacheBuildId == m_BuildId);
@@ -447,7 +444,7 @@ public:
try
{
- Result.Blocks = zen::GetBlockDescriptions(*Output,
+ Result.Blocks = zen::GetBlockDescriptions(Log(),
*m_BuildStorage,
OptionalCache,
m_BuildId,
diff --git a/src/zenremotestore/projectstore/projectstoreoperations.cpp b/src/zenremotestore/projectstore/projectstoreoperations.cpp
index 36dc4d868..ba4b74825 100644
--- a/src/zenremotestore/projectstore/projectstoreoperations.cpp
+++ b/src/zenremotestore/projectstore/projectstoreoperations.cpp
@@ -3,13 +3,14 @@
#include <zenremotestore/projectstore/projectstoreoperations.h>
#include <zencore/compactbinaryutil.h>
+#include <zencore/fmtutils.h>
#include <zencore/parallelwork.h>
#include <zencore/scopeguard.h>
#include <zencore/timer.h>
#include <zenremotestore/builds/buildstorageutil.h>
#include <zenremotestore/chunking/chunkedfile.h>
-#include <zenremotestore/operationlogoutput.h>
#include <zenremotestore/projectstore/remoteprojectstore.h>
+#include <zenutil/progress.h>
namespace zen {
@@ -17,11 +18,11 @@ using namespace std::literals;
//////////////////////////// ProjectStoreOperationOplogState
-ProjectStoreOperationOplogState::ProjectStoreOperationOplogState(OperationLogOutput& OperationLogOutput,
- StorageInstance& Storage,
- const Oid& BuildId,
- const Options& Options)
-: m_LogOutput(OperationLogOutput)
+ProjectStoreOperationOplogState::ProjectStoreOperationOplogState(LoggerRef Log,
+ StorageInstance& Storage,
+ const Oid& BuildId,
+ const Options& Options)
+: m_Log(Log)
, m_Storage(Storage)
, m_BuildId(BuildId)
, m_Options(Options)
@@ -48,10 +49,7 @@ ProjectStoreOperationOplogState::LoadBuildObject()
{
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Read build {} from locally cached file in {}",
- m_BuildId,
- NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ ZEN_INFO("Read build {} from locally cached file in {}", m_BuildId, NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
return m_BuildObject;
}
@@ -61,11 +59,10 @@ ProjectStoreOperationOplogState::LoadBuildObject()
m_BuildObject = m_Storage.BuildStorage->GetBuild(m_BuildId);
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Fetched build {} from {} in {}",
- m_BuildId,
- m_Storage.BuildStorageHttp->GetBaseUri(),
- NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ ZEN_INFO("Fetched build {} from {} in {}",
+ m_BuildId,
+ m_Storage.BuildStorageHttp->GetBaseUri(),
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
CreateDirectories(CachedBuildObjectPath.parent_path());
TemporaryFile::SafeWriteFile(CachedBuildObjectPath, m_BuildObject.GetBuffer().GetView());
@@ -122,11 +119,10 @@ ProjectStoreOperationOplogState::LoadBuildPartsObject()
{
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Read build part {}/{} from locally cached file in {}",
- m_BuildId,
- BuildPartId,
- NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ ZEN_INFO("Read build part {}/{} from locally cached file in {}",
+ m_BuildId,
+ BuildPartId,
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
return m_BuildPartsObject;
}
@@ -136,12 +132,11 @@ ProjectStoreOperationOplogState::LoadBuildPartsObject()
m_BuildPartsObject = m_Storage.BuildStorage->GetBuildPart(m_BuildId, BuildPartId);
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Fetched build part {}/{} from {} in {}",
- m_BuildId,
- BuildPartId,
- m_Storage.BuildStorageHttp->GetBaseUri(),
- NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ ZEN_INFO("Fetched build part {}/{} from {} in {}",
+ m_BuildId,
+ BuildPartId,
+ m_Storage.BuildStorageHttp->GetBaseUri(),
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
CreateDirectories(CachedBuildPartObjectPath.parent_path());
TemporaryFile::SafeWriteFile(CachedBuildPartObjectPath, m_BuildPartsObject.GetBuffer().GetView());
@@ -168,11 +163,7 @@ ProjectStoreOperationOplogState::LoadOpsSectionObject()
}
else if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Read {}/{}/ops from locally cached file in {}",
- BuildPartId,
- m_BuildId,
- NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ ZEN_INFO("Read {}/{}/ops from locally cached file in {}", BuildPartId, m_BuildId, NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
return m_OpsSectionObject;
}
}
@@ -193,11 +184,10 @@ ProjectStoreOperationOplogState::LoadOpsSectionObject()
}
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Decompressed and validated oplog payload {} -> {} in {}",
- NiceBytes(OpsSection.GetSize()),
- NiceBytes(m_OpsSectionObject.GetSize()),
- NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ ZEN_INFO("Decompressed and validated oplog payload {} -> {} in {}",
+ NiceBytes(OpsSection.GetSize()),
+ NiceBytes(m_OpsSectionObject.GetSize()),
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
if (m_OpsSectionObject)
{
@@ -226,12 +216,11 @@ ProjectStoreOperationOplogState::LoadArrayFromBuildPart(std::string_view ArrayNa
{
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Read {}/{}/{} from locally cached file in {}",
- BuildPartId,
- m_BuildId,
- ArrayName,
- NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ ZEN_INFO("Read {}/{}/{} from locally cached file in {}",
+ BuildPartId,
+ m_BuildId,
+ ArrayName,
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
CbArray Result = CbArray(SharedBuffer(std::move(Payload)));
return Result;
@@ -290,7 +279,8 @@ ProjectStoreOperationOplogState::LoadChunksArray()
//////////////////////////// ProjectStoreOperationDownloadAttachments
-ProjectStoreOperationDownloadAttachments::ProjectStoreOperationDownloadAttachments(OperationLogOutput& OperationLogOutput,
+ProjectStoreOperationDownloadAttachments::ProjectStoreOperationDownloadAttachments(LoggerRef Log,
+ ProgressBase& Progress,
StorageInstance& Storage,
std::atomic<bool>& AbortFlag,
std::atomic<bool>& PauseFlag,
@@ -299,7 +289,8 @@ ProjectStoreOperationDownloadAttachments::ProjectStoreOperationDownloadAttachmen
ProjectStoreOperationOplogState& State,
std::span<const IoHash> AttachmentHashes,
const Options& Options)
-: m_LogOutput(OperationLogOutput)
+: m_Log(Log)
+, m_Progress(Progress)
, m_Storage(Storage)
, m_AbortFlag(AbortFlag)
, m_PauseFlag(PauseFlag)
@@ -325,9 +316,9 @@ ProjectStoreOperationDownloadAttachments::Execute()
};
auto EndProgress =
- MakeGuard([&]() { m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::StepCount, (uint32_t)TaskSteps::StepCount); });
+ MakeGuard([&]() { m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::StepCount, (uint32_t)TaskSteps::StepCount); });
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::ReadAttachmentData, (uint32_t)TaskSteps::StepCount);
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::ReadAttachmentData, (uint32_t)TaskSteps::StepCount);
Stopwatch Timer;
tsl::robin_map<IoHash, uint64_t, IoHash::Hasher> ChunkSizes;
@@ -415,13 +406,12 @@ ProjectStoreOperationDownloadAttachments::Execute()
FilesToDechunk.size() > 0
? fmt::format("\n{} file{} needs to be dechunked", FilesToDechunk.size(), FilesToDechunk.size() == 1 ? "" : "s")
: "";
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Need to download {} block{} and {} chunk{}{}",
- BlocksToDownload.size(),
- BlocksToDownload.size() == 1 ? "" : "s",
- LooseChunksToDownload.size(),
- LooseChunksToDownload.size() == 1 ? "" : "s",
- DechunkInfo);
+ ZEN_INFO("Need to download {} block{} and {} chunk{}{}",
+ BlocksToDownload.size(),
+ BlocksToDownload.size() == 1 ? "" : "s",
+ LooseChunksToDownload.size(),
+ LooseChunksToDownload.size() == 1 ? "" : "s",
+ DechunkInfo);
}
auto GetBuildBlob = [this](const IoHash& RawHash, const std::filesystem::path& OutputPath) {
@@ -470,18 +460,15 @@ ProjectStoreOperationDownloadAttachments::Execute()
std::filesystem::path TempAttachmentPath = MakeSafeAbsolutePath(m_Options.AttachmentOutputPath) / ".tmp";
CreateDirectories(TempAttachmentPath);
auto _0 = MakeGuard([this, &TempAttachmentPath]() {
- if (true)
+ if (!m_Options.IsQuiet)
{
- if (!m_Options.IsQuiet)
- {
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Cleaning up temporary directory");
- }
- CleanDirectory(TempAttachmentPath, true);
- RemoveDir(TempAttachmentPath);
+ ZEN_INFO("Cleaning up temporary directory");
}
+ CleanDirectory(TempAttachmentPath, true);
+ RemoveDir(TempAttachmentPath);
});
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::Download, (uint32_t)TaskSteps::StepCount);
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::Download, (uint32_t)TaskSteps::StepCount);
std::filesystem::path BlocksPath = TempAttachmentPath / "blocks";
CreateDirectories(BlocksPath);
@@ -492,11 +479,9 @@ ProjectStoreOperationDownloadAttachments::Execute()
std::filesystem::path LooseChunksPath = TempAttachmentPath / "loosechunks";
CreateDirectories(LooseChunksPath);
- std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Downloading"));
- OperationLogOutput::ProgressBar& DownloadProgressBar(*ProgressBarPtr);
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Downloading");
- std::atomic<bool> PauseFlag;
- ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+ ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
std::atomic<size_t> LooseChunksCompleted;
std::atomic<size_t> BlocksCompleted;
@@ -511,7 +496,7 @@ ProjectStoreOperationDownloadAttachments::Execute()
if (m_Options.ForceDownload || !IsFile(LooseChunkOutputPath))
{
GetBuildBlob(RawHash, LooseChunkOutputPath);
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput, "Downloaded chunk {}", RawHash);
+ ZEN_DEBUG("Downloaded chunk {}", RawHash);
}
Work.ScheduleWork(m_IOWorkerPool, [&, LooseChunkIndex, LooseChunkOutputPath](std::atomic<bool>&) {
@@ -547,7 +532,7 @@ ProjectStoreOperationDownloadAttachments::Execute()
{
ChunkOutput.Close();
RemoveFile(ChunkOutputPath);
- throw std::runtime_error(fmt::format("Failed to decompress chunk {} to ", RawHash, ChunkOutputPath));
+ throw std::runtime_error(fmt::format("Failed to decompress chunk {} to '{}'", RawHash, ChunkOutputPath));
}
}
else
@@ -555,7 +540,7 @@ ProjectStoreOperationDownloadAttachments::Execute()
TemporaryFile::SafeWriteFile(ChunkOutputPath, CompressedChunk.GetCompressed());
}
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput, "Wrote loose chunk {} to '{}'", RawHash, ChunkOutputPath);
+ ZEN_DEBUG("Wrote loose chunk {} to '{}'", RawHash, ChunkOutputPath);
LooseChunksCompleted++;
});
});
@@ -572,7 +557,7 @@ ProjectStoreOperationDownloadAttachments::Execute()
if (m_Options.ForceDownload || !IsFile(BlockOutputPath))
{
GetBuildBlob(RawHash, BlockOutputPath);
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput, "Downloaded block {}", RawHash);
+ ZEN_DEBUG("Downloaded block {}", RawHash);
}
Work.ScheduleWork(m_IOWorkerPool, [&, BlockIndex, BlockOutputPath](std::atomic<bool>&) {
@@ -607,7 +592,7 @@ ProjectStoreOperationDownloadAttachments::Execute()
ChunkOutput.Close();
RemoveFile(ChunkOutputPath);
throw std::runtime_error(
- fmt::format("Failed to decompress chunk {} to ", ChunkHash, ChunkOutputPath));
+ fmt::format("Failed to decompress chunk {} to '{}'", ChunkHash, ChunkOutputPath));
}
}
else
@@ -615,7 +600,7 @@ ProjectStoreOperationDownloadAttachments::Execute()
TemporaryFile::SafeWriteFile(ChunkOutputPath, CompressedChunk.GetCompressed());
}
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput, "Wrote block chunk {} to '{}'", ChunkHash, ChunkOutputPath);
+ ZEN_DEBUG("Wrote block chunk {} to '{}'", ChunkHash, ChunkOutputPath);
}
if (ChunkedFileRawHashes.contains(ChunkHash))
{
@@ -635,7 +620,7 @@ ProjectStoreOperationDownloadAttachments::Execute()
});
}
- Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
+ Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
ZEN_UNUSED(IsAborted, IsPaused, PendingWork);
std::string Details = fmt::format("{}/{} blocks, {}/{} chunks downloaded",
@@ -643,39 +628,37 @@ ProjectStoreOperationDownloadAttachments::Execute()
BlocksToDownload.size(),
LooseChunksCompleted.load(),
LooseChunksToDownload.size());
- DownloadProgressBar.UpdateState({.Task = "Downloading",
- .Details = Details,
- .TotalCount = BlocksToDownload.size() + LooseChunksToDownload.size(),
- .RemainingCount = BlocksToDownload.size() + LooseChunksToDownload.size() -
- (BlocksCompleted.load() + LooseChunksCompleted.load()),
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
- false);
+ ProgressBar->UpdateState({.Task = "Downloading",
+ .Details = Details,
+ .TotalCount = BlocksToDownload.size() + LooseChunksToDownload.size(),
+ .RemainingCount = BlocksToDownload.size() + LooseChunksToDownload.size() -
+ (BlocksCompleted.load() + LooseChunksCompleted.load()),
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ false);
});
- DownloadProgressBar.Finish();
+ ProgressBar->Finish();
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "{} block{} downloaded, {} loose chunk{} downloaded in {}",
- BlocksToDownload.size(),
- BlocksToDownload.size() == 1 ? "" : "s",
- LooseChunksToDownload.size(),
- LooseChunksToDownload.size() == 1 ? "" : "s",
- NiceTimeSpanMs(DownloadTimer.GetElapsedTimeMs()));
+ ZEN_INFO("{} block{} downloaded, {} loose chunk{} downloaded in {}",
+ BlocksToDownload.size(),
+ BlocksToDownload.size() == 1 ? "" : "s",
+ LooseChunksToDownload.size(),
+ LooseChunksToDownload.size() == 1 ? "" : "s",
+ NiceTimeSpanMs(DownloadTimer.GetElapsedTimeMs()));
}
}
if (!ChunkedFileInfos.empty())
{
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::AnalyzeDechunk, (uint32_t)TaskSteps::StepCount);
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::AnalyzeDechunk, (uint32_t)TaskSteps::StepCount);
std::filesystem::path ChunkedFilesPath = TempAttachmentPath / "chunkedfiles";
CreateDirectories(ChunkedFilesPath);
try
{
- std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Dechunking"));
- OperationLogOutput::ProgressBar& DechunkingProgressBar(*ProgressBarPtr);
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Dechunking");
std::atomic<uint64_t> ChunksWritten;
@@ -729,7 +712,7 @@ ProjectStoreOperationDownloadAttachments::Execute()
PrepareFileForScatteredWrite(OpenChunkedFiles.back()->Handle(), ChunkedFileInfo.RawSize);
}
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::Dechunk, (uint32_t)TaskSteps::StepCount);
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::Dechunk, (uint32_t)TaskSteps::StepCount);
std::vector<std::atomic<uint8_t>> ChunkWrittenFlags(ChunkOpenFileTargets.size());
@@ -755,7 +738,7 @@ ProjectStoreOperationDownloadAttachments::Execute()
}))
{
std::error_code DummyEc;
- throw std::runtime_error(fmt::format("Failed to decompress chunk {} at offset {} to {}",
+ throw std::runtime_error(fmt::format("Failed to decompress chunk {} at offset {} to '{}'",
CompressedChunkBuffer.DecodeRawHash(),
ChunkTarget.Offset,
PathFromHandle(OutputFile.Handle(), DummyEc)));
@@ -768,8 +751,7 @@ ProjectStoreOperationDownloadAttachments::Execute()
{
Stopwatch DechunkTimer;
- std::atomic<bool> PauseFlag;
- ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+ ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
std::vector<IoHash> LooseChunks(LooseChunksToDownload.begin(), LooseChunksToDownload.end());
@@ -819,26 +801,24 @@ ProjectStoreOperationDownloadAttachments::Execute()
}
});
}
- Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
+ Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
ZEN_UNUSED(IsAborted, IsPaused, PendingWork);
std::string Details = fmt::format("{}/{} chunks written", ChunksWritten.load(), ChunkOpenFileTargets.size());
- DechunkingProgressBar.UpdateState(
- {.Task = "Dechunking ",
- .Details = Details,
- .TotalCount = ChunkOpenFileTargets.size(),
- .RemainingCount = ChunkOpenFileTargets.size() - ChunksWritten.load(),
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
- false);
+ ProgressBar->UpdateState({.Task = "Dechunking ",
+ .Details = Details,
+ .TotalCount = ChunkOpenFileTargets.size(),
+ .RemainingCount = ChunkOpenFileTargets.size() - ChunksWritten.load(),
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ false);
});
- DechunkingProgressBar.Finish();
+ ProgressBar->Finish();
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "{} file{} dechunked in {}",
- ChunkedFileInfos.size(),
- ChunkedFileInfos.size() == 1 ? "" : "s",
- NiceTimeSpanMs(DechunkTimer.GetElapsedTimeMs()));
+ ZEN_INFO("{} file{} dechunked in {}",
+ ChunkedFileInfos.size(),
+ ChunkedFileInfos.size() == 1 ? "" : "s",
+ NiceTimeSpanMs(DechunkTimer.GetElapsedTimeMs()));
}
}
}
@@ -853,12 +833,10 @@ ProjectStoreOperationDownloadAttachments::Execute()
throw;
}
{
- Stopwatch VerifyTimer;
- std::unique_ptr<OperationLogOutput::ProgressBar> ProgressBarPtr(m_LogOutput.CreateProgressBar("Verifying"));
- OperationLogOutput::ProgressBar& VerifyProgressBar(*ProgressBarPtr);
+ Stopwatch VerifyTimer;
+ std::unique_ptr<ProgressBase::ProgressBar> ProgressBar = m_Progress.CreateProgressBar("Verifying");
- std::atomic<bool> PauseFlag;
- ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+ ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
std::atomic<size_t> DechunkedFilesMoved;
@@ -875,43 +853,41 @@ ProjectStoreOperationDownloadAttachments::Execute()
}
std::filesystem::path ChunkOutputPath = m_Options.AttachmentOutputPath / fmt::format("{}", ChunkedFileInfo.RawHash);
RenameFile(ChunkedFilePath, ChunkOutputPath);
- ZEN_OPERATION_LOG_DEBUG(m_LogOutput, "Moved dechunked file {} to '{}'", ChunkedFileInfo.RawHash, ChunkOutputPath);
+ ZEN_DEBUG("Moved dechunked file {} to '{}'", ChunkedFileInfo.RawHash, ChunkOutputPath);
DechunkedFilesMoved++;
});
}
- Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
+ Work.Wait(m_Progress.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) {
ZEN_UNUSED(IsAborted, IsPaused, PendingWork);
std::string Details = fmt::format("{}/{} files verified", DechunkedFilesMoved.load(), ChunkedFileInfos.size());
- VerifyProgressBar.UpdateState({.Task = "Verifying ",
- .Details = Details,
- .TotalCount = ChunkedFileInfos.size(),
- .RemainingCount = ChunkedFileInfos.size() - DechunkedFilesMoved.load(),
- .Status = OperationLogOutput::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
- false);
+ ProgressBar->UpdateState({.Task = "Verifying ",
+ .Details = Details,
+ .TotalCount = ChunkedFileInfos.size(),
+ .RemainingCount = ChunkedFileInfos.size() - DechunkedFilesMoved.load(),
+ .Status = ProgressBase::ProgressBar::State::CalculateStatus(IsAborted, IsPaused)},
+ false);
});
- VerifyProgressBar.Finish();
+ ProgressBar->Finish();
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Verified {} chunked file{} in {}",
- ChunkedFileInfos.size(),
- ChunkedFileInfos.size() == 1 ? "" : "s",
- NiceTimeSpanMs(VerifyTimer.GetElapsedTimeMs()));
+ ZEN_INFO("Verified {} chunked file{} in {}",
+ ChunkedFileInfos.size(),
+ ChunkedFileInfos.size() == 1 ? "" : "s",
+ NiceTimeSpanMs(VerifyTimer.GetElapsedTimeMs()));
}
}
}
if (!m_Options.IsQuiet)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Downloaded {} attachment{} to '{}' in {}",
- m_AttachmentHashes.size(),
- m_AttachmentHashes.size() == 1 ? "" : "s",
- m_Options.AttachmentOutputPath,
- NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ ZEN_INFO("Downloaded {} attachment{} to '{}' in {}",
+ m_AttachmentHashes.size(),
+ m_AttachmentHashes.size() == 1 ? "" : "s",
+ m_Options.AttachmentOutputPath,
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
- m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::Cleanup, (uint32_t)TaskSteps::StepCount);
+ m_Progress.SetLogOperationProgress((uint32_t)TaskSteps::Cleanup, (uint32_t)TaskSteps::StepCount);
}
} // namespace zen
diff --git a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp
index 1a9dc10ef..f43f0813a 100644
--- a/src/zenremotestore/projectstore/remoteprojectstore.cpp
+++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp
@@ -8,6 +8,8 @@
#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
+#include <zencore/logging/broadcastsink.h>
+#include <zencore/logging/logger.h>
#include <zencore/parallelwork.h>
#include <zencore/scopeguard.h>
#include <zencore/stream.h>
@@ -18,8 +20,9 @@
#include <zenremotestore/builds/buildstoragecache.h>
#include <zenremotestore/chunking/chunkedcontent.h>
#include <zenremotestore/chunking/chunkedfile.h>
-#include <zenremotestore/operationlogoutput.h>
#include <zenstore/cidstore.h>
+#include <zenutil/logging.h>
+#include <zenutil/progress.h>
#include <numeric>
#include <unordered_map>
@@ -392,7 +395,10 @@ namespace remotestore_impl {
OodleCompressor Compressor,
OodleCompressionLevel CompressionLevel)
{
- ZEN_ASSERT(!IsFile(AttachmentPath));
+ if (IsFile(AttachmentPath))
+ {
+ ZEN_WARN("Temp attachment file already exists at '{}', truncating", AttachmentPath);
+ }
BasicFile CompressedFile;
std::error_code Ec;
CompressedFile.Open(AttachmentPath, BasicFile::Mode::kTruncateDelete, Ec);
@@ -448,6 +454,7 @@ namespace remotestore_impl {
};
CbObject RewriteOplog(
+ LoggerRef InLog,
ProjectStore::Project& Project,
ProjectStore::Oplog& Oplog,
bool IgnoreMissingAttachments,
@@ -456,6 +463,7 @@ namespace remotestore_impl {
std::unordered_map<IoHash, FoundAttachment, IoHash::Hasher>& UploadAttachments, // TODO: Rename to OutUploadAttachments
JobContext* OptionalContext)
{
+ ZEN_SCOPED_LOG(InLog);
size_t OpCount = 0;
CreateDirectories(AttachmentTempPath);
@@ -929,7 +937,6 @@ namespace remotestore_impl {
{
return;
}
- ZEN_ASSERT(UploadAttachment->Size != 0);
if (!UploadAttachment->RawPath.empty())
{
if (UploadAttachment->Size > (MaxChunkEmbedSize * 2))
@@ -1140,31 +1147,51 @@ namespace remotestore_impl {
std::atomic<uint64_t> ChunksCompleteCount = 0;
};
- class JobContextLogOutput : public OperationLogOutput
+ class JobContextSink : public logging::Sink
{
public:
- JobContextLogOutput(JobContext* OptionalContext) : m_OptionalContext(OptionalContext) {}
- virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override
+ explicit JobContextSink(JobContext* Context) : m_Context(Context) {}
+
+ void Log(const logging::LogMessage& Msg) override
{
- if (m_OptionalContext)
+ if (m_Context)
{
- fmt::basic_memory_buffer<char, 250> MessageBuffer;
- fmt::vformat_to(fmt::appender(MessageBuffer), Point.FormatString, Args);
- remotestore_impl::ReportMessage(m_OptionalContext, std::string_view(MessageBuffer.data(), MessageBuffer.size()));
+ m_Context->ReportMessage(Msg.GetPayload());
}
}
- virtual void SetLogOperationName(std::string_view Name) override { ZEN_UNUSED(Name); }
- virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { ZEN_UNUSED(StepIndex, StepCount); }
- virtual uint32_t GetProgressUpdateDelayMS() override { return 0; }
- virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override
+ void Flush() override {}
+ void SetFormatter(std::unique_ptr<logging::Formatter>) override {}
+
+ private:
+ JobContext* m_Context;
+ };
+
+ class JobContextLogger
+ {
+ public:
+ explicit JobContextLogger(JobContext* OptionalContext)
{
- ZEN_UNUSED(InSubTask);
- return nullptr;
+ if (!OptionalContext)
+ {
+ return;
+ }
+ logging::SinkPtr ContextSink(new JobContextSink(OptionalContext));
+ Ref<logging::BroadcastSink> DefaultSink = GetDefaultBroadcastSink();
+ std::vector<logging::SinkPtr> Sinks;
+ if (DefaultSink)
+ {
+ Sinks.push_back(DefaultSink);
+ }
+ Sinks.push_back(std::move(ContextSink));
+ Ref<logging::BroadcastSink> Broadcast(new logging::BroadcastSink(std::move(Sinks)));
+ m_Log = Ref<logging::Logger>(new logging::Logger("jobcontext", Broadcast));
}
+ LoggerRef Log() const { return m_Log ? LoggerRef(*m_Log) : zen::Log(); }
+
private:
- JobContext* m_OptionalContext;
+ Ref<logging::Logger> m_Log;
};
void DownloadAndSaveBlockChunks(LoadOplogContext& Context,
@@ -1185,6 +1212,7 @@ namespace remotestore_impl {
&LoadAttachmentsTimer,
&DownloadStartMS](std::atomic<bool>& AbortFlag) {
ZEN_TRACE_CPU("DownloadBlockChunks");
+ ZEN_SCOPED_LOG(Context.Log);
if (AbortFlag)
{
@@ -1300,6 +1328,7 @@ namespace remotestore_impl {
&AllNeededPartialChunkHashesLookup,
ChunkDownloadedFlags](std::atomic<bool>& AbortFlag) {
ZEN_TRACE_CPU("DownloadBlock");
+ ZEN_SCOPED_LOG(Context.Log);
if (AbortFlag)
{
@@ -1366,6 +1395,7 @@ namespace remotestore_impl {
&AllNeededPartialChunkHashesLookup,
ChunkDownloadedFlags,
Bytes = std::move(BlobBuffer)](std::atomic<bool>& AbortFlag) {
+ ZEN_SCOPED_LOG(Context.Log);
if (AbortFlag)
{
return;
@@ -1715,6 +1745,7 @@ namespace remotestore_impl {
ChunkDownloadedFlags,
RetriesLeft](std::atomic<bool>& AbortFlag) {
ZEN_TRACE_CPU("DownloadBlockRanges");
+ ZEN_SCOPED_LOG(Context.Log);
try
{
uint64_t Unset = (std::uint64_t)-1;
@@ -1760,6 +1791,7 @@ namespace remotestore_impl {
OffsetAndLengths =
std::vector<std::pair<uint64_t, uint64_t>>(OffsetAndLengths.begin(), OffsetAndLengths.end())](
std::atomic<bool>& AbortFlag) {
+ ZEN_SCOPED_LOG(Context.Log);
try
{
ZEN_ASSERT(BlockPayload.Size() > 0);
@@ -1972,6 +2004,7 @@ namespace remotestore_impl {
Context.NetworkWorkerPool,
[&Context, &AttachmentWork, RawHash, &LoadAttachmentsTimer, &DownloadStartMS, &Info](std::atomic<bool>& AbortFlag) {
ZEN_TRACE_CPU("DownloadAttachment");
+ ZEN_SCOPED_LOG(Context.Log);
if (AbortFlag)
{
@@ -2061,7 +2094,8 @@ namespace remotestore_impl {
WorkerThreadPool::EMode::EnableBacklog);
};
- void AsyncCreateBlock(ParallelWork& Work,
+ void AsyncCreateBlock(LoggerRef InLog,
+ ParallelWork& Work,
WorkerThreadPool& WorkerPool,
std::vector<std::pair<IoHash, FetchChunkFunc>>&& ChunksInBlock,
RwLock& SectionsLock,
@@ -2071,9 +2105,10 @@ namespace remotestore_impl {
JobContext* OptionalContext)
{
Work.ScheduleWork(WorkerPool,
- [&Blocks, &SectionsLock, BlockIndex, Chunks = std::move(ChunksInBlock), &AsyncOnBlock, OptionalContext](
+ [InLog, &Blocks, &SectionsLock, BlockIndex, Chunks = std::move(ChunksInBlock), &AsyncOnBlock, OptionalContext](
std::atomic<bool>& AbortFlag) mutable {
ZEN_TRACE_CPU("CreateBlock");
+ ZEN_SCOPED_LOG(InLog);
if (remotestore_impl::IsCancelled(OptionalContext))
{
@@ -2452,7 +2487,8 @@ GetBlocksFromOplog(CbObjectView ContainerObject, std::span<const IoHash> Include
}
CbObject
-BuildContainer(CidStore& ChunkStore,
+BuildContainer(LoggerRef InLog,
+ CidStore& ChunkStore,
ProjectStore::Project& Project,
ProjectStore::Oplog& Oplog,
size_t MaxBlockSize,
@@ -2472,7 +2508,8 @@ BuildContainer(CidStore& ChunkStore,
{
using namespace std::literals;
- std::unique_ptr<OperationLogOutput> LogOutput(std::make_unique<remotestore_impl::JobContextLogOutput>(OptionalContext));
+ ZEN_SCOPED_LOG(InLog);
+ remotestore_impl::JobContextLogger JobContextOutput(OptionalContext);
Stopwatch Timer;
@@ -2485,7 +2522,8 @@ BuildContainer(CidStore& ChunkStore,
size_t TotalOpCount = Oplog.GetOplogEntryCount();
Stopwatch RewriteOplogTimer;
- CbObject SectionOps = remotestore_impl::RewriteOplog(Project,
+ CbObject SectionOps = remotestore_impl::RewriteOplog(InLog,
+ Project,
Oplog,
IgnoreMissingAttachments,
EmbedLooseFiles,
@@ -2605,7 +2643,7 @@ BuildContainer(CidStore& ChunkStore,
std::vector<uint32_t> UnusedChunkIndexes;
ReuseBlocksStatistics ReuseBlocksStats;
- ReusedBlockIndexes = FindReuseBlocks(*LogOutput,
+ ReusedBlockIndexes = FindReuseBlocks(JobContextOutput.Log(),
/*BlockReuseMinPercentLimit*/ 80,
/*IsVerbose*/ false,
ReuseBlocksStats,
@@ -2749,7 +2787,8 @@ BuildContainer(CidStore& ChunkStore,
.MaxChunkEmbedSize = MaxChunkEmbedSize,
.IsCancelledFunc = [OptionalContext]() { return remotestore_impl::IsCancelled(OptionalContext); }});
- auto OnNewBlock = [&Work,
+ auto OnNewBlock = [&Log,
+ &Work,
&WorkerPool,
BuildBlocks,
&BlockCreateProgressTimer,
@@ -2774,7 +2813,8 @@ BuildContainer(CidStore& ChunkStore,
size_t BlockIndex = remotestore_impl::AddBlock(BlocksLock, Blocks);
if (BuildBlocks)
{
- remotestore_impl::AsyncCreateBlock(Work,
+ remotestore_impl::AsyncCreateBlock(Log(),
+ Work,
WorkerPool,
std::move(ChunksInBlock),
BlocksLock,
@@ -3007,6 +3047,17 @@ BuildContainer(CidStore& ChunkStore,
return {};
}
+ // Reused blocks were not composed (their chunks were erased from UploadAttachments) but must
+ // still appear in the container so that a fresh receiver knows to download them.
+ if (BuildBlocks)
+ {
+ for (size_t KnownBlockIndex : ReusedBlockIndexes)
+ {
+ const ChunkBlockDescription& Reused = KnownBlocks[KnownBlockIndex];
+ Blocks.push_back(Reused);
+ }
+ }
+
CbObjectWriter OplogContainerWriter;
RwLock::SharedLockScope _(BlocksLock);
OplogContainerWriter.AddBinary("ops"sv, CompressedOpsSection.GetCompressed().Flatten().AsIoBuffer());
@@ -3096,7 +3147,8 @@ BuildContainer(CidStore& ChunkStore,
}
CbObject
-BuildContainer(CidStore& ChunkStore,
+BuildContainer(LoggerRef InLog,
+ CidStore& ChunkStore,
ProjectStore::Project& Project,
ProjectStore::Oplog& Oplog,
WorkerThreadPool& WorkerPool,
@@ -3112,7 +3164,8 @@ BuildContainer(CidStore& ChunkStore,
const std::function<void(std::vector<std::pair<IoHash, FetchChunkFunc>>&&)>& OnBlockChunks,
bool EmbedLooseFiles)
{
- return BuildContainer(ChunkStore,
+ return BuildContainer(InLog,
+ ChunkStore,
Project,
Oplog,
MaxBlockSize,
@@ -3132,7 +3185,8 @@ BuildContainer(CidStore& ChunkStore,
}
void
-SaveOplog(CidStore& ChunkStore,
+SaveOplog(LoggerRef InLog,
+ CidStore& ChunkStore,
RemoteProjectStore& RemoteStore,
ProjectStore::Project& Project,
ProjectStore::Oplog& Oplog,
@@ -3149,6 +3203,7 @@ SaveOplog(CidStore& ChunkStore,
{
using namespace std::literals;
+ ZEN_SCOPED_LOG(InLog);
Stopwatch Timer;
remotestore_impl::UploadInfo Info;
@@ -3168,8 +3223,8 @@ SaveOplog(CidStore& ChunkStore,
std::unordered_map<IoHash, remotestore_impl::CreatedBlock, IoHash::Hasher> CreatedBlocks;
tsl::robin_map<IoHash, TGetAttachmentBufferFunc, IoHash::Hasher> LooseLargeFiles;
- auto MakeTempBlock = [AttachmentTempPath, &AttachmentsLock, &CreatedBlocks](CompressedBuffer&& CompressedBlock,
- ChunkBlockDescription&& Block) {
+ auto MakeTempBlock = [&Log, AttachmentTempPath, &AttachmentsLock, &CreatedBlocks](CompressedBuffer&& CompressedBlock,
+ ChunkBlockDescription&& Block) {
std::filesystem::path BlockPath = AttachmentTempPath;
BlockPath.append(Block.BlockHash.ToHexString());
IoBuffer BlockBuffer = WriteToTempFile(std::move(CompressedBlock).GetCompressed(), BlockPath);
@@ -3180,8 +3235,8 @@ SaveOplog(CidStore& ChunkStore,
ZEN_DEBUG("Saved temp block to '{}', {}", AttachmentTempPath, NiceBytes(BlockSize));
};
- auto UploadBlock = [&RemoteStore, &RemoteStoreInfo, &Info, OptionalContext](CompressedBuffer&& CompressedBlock,
- ChunkBlockDescription&& Block) {
+ auto UploadBlock = [&Log, &RemoteStore, &RemoteStoreInfo, &Info, OptionalContext](CompressedBuffer&& CompressedBlock,
+ ChunkBlockDescription&& Block) {
IoHash BlockHash = Block.BlockHash;
uint64_t CompressedSize = CompressedBlock.GetCompressedSize();
RemoteProjectStore::SaveAttachmentResult Result =
@@ -3201,13 +3256,13 @@ SaveOplog(CidStore& ChunkStore,
};
std::vector<std::vector<std::pair<IoHash, FetchChunkFunc>>> BlockChunks;
- auto OnBlockChunks = [&BlockChunks](std::vector<std::pair<IoHash, FetchChunkFunc>>&& Chunks) {
+ auto OnBlockChunks = [&Log, &BlockChunks](std::vector<std::pair<IoHash, FetchChunkFunc>>&& Chunks) {
BlockChunks.push_back({std::make_move_iterator(Chunks.begin()), std::make_move_iterator(Chunks.end())});
ZEN_DEBUG("Found {} block chunks", Chunks.size());
};
- auto OnLargeAttachment = [&AttachmentsLock, &LargeAttachments, &LooseLargeFiles](const IoHash& AttachmentHash,
- TGetAttachmentBufferFunc&& GetBufferFunc) {
+ auto OnLargeAttachment = [&Log, &AttachmentsLock, &LargeAttachments, &LooseLargeFiles](const IoHash& AttachmentHash,
+ TGetAttachmentBufferFunc&& GetBufferFunc) {
{
RwLock::ExclusiveLockScope _(AttachmentsLock);
LargeAttachments.insert(AttachmentHash);
@@ -3286,7 +3341,8 @@ SaveOplog(CidStore& ChunkStore,
}
}
- CbObject OplogContainerObject = BuildContainer(ChunkStore,
+ CbObject OplogContainerObject = BuildContainer(InLog,
+ ChunkStore,
Project,
Oplog,
MaxBlockSize,
@@ -3694,7 +3750,8 @@ LoadOplog(LoadOplogContext&& Context)
{
using namespace std::literals;
- std::unique_ptr<OperationLogOutput> LogOutput(std::make_unique<remotestore_impl::JobContextLogOutput>(Context.OptionalJobContext));
+ ZEN_SCOPED_LOG(Context.Log);
+ remotestore_impl::JobContextLogger JobContextOutput(Context.OptionalJobContext);
remotestore_impl::DownloadInfo Info;
@@ -3985,7 +4042,7 @@ LoadOplog(LoadOplogContext&& Context)
ZEN_ASSERT(PartialBlockDownloadModes.size() == BlocksWithDescription.size());
ChunkBlockAnalyser PartialAnalyser(
- *LogOutput,
+ JobContextOutput.Log(),
BlockDescriptions.Blocks,
ChunkBlockAnalyser::Options{.IsQuiet = false,
.IsVerbose = false,
@@ -4108,10 +4165,10 @@ LoadOplog(LoadOplogContext&& Context)
std::filesystem::path TempFileName = TempFilePath / Chunked.RawHash.ToHexString();
DechunkWork.ScheduleWork(
Context.WorkerPool,
- [&Context, TempFileName, &FilesToDechunk, ChunkedIndex, &Info](std::atomic<bool>& AbortFlag) {
+ [&Log, &Context, TempFileName, &FilesToDechunk, ChunkedIndex, &Info](std::atomic<bool>& AbortFlag) {
ZEN_TRACE_CPU("DechunkAttachment");
- auto _ = MakeGuard([&TempFileName] {
+ auto _ = MakeGuard([&Log, &TempFileName] {
std::error_code Ec;
if (IsFile(TempFileName, Ec))
{
@@ -4712,7 +4769,8 @@ TEST_CASE_TEMPLATE("project.store.export",
WorkerThreadPool& NetworkPool = Pools.NetworkPool;
WorkerThreadPool& WorkerPool = Pools.WorkerPool;
- SaveOplog(CidStore,
+ SaveOplog(Log(),
+ CidStore,
*RemoteStore,
*Project.Get(),
*Oplog,
@@ -4732,7 +4790,8 @@ TEST_CASE_TEMPLATE("project.store.export",
CapturingJobContext Ctx;
auto DoLoad = [&](bool Force, bool Clean) {
- LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.OptionalCache = nullptr,
.CacheBuildId = Oid::Zero,
@@ -4793,7 +4852,8 @@ SetupExportStore(CidStore& CidStore,
/*.ForceEnableTempBlocks =*/false};
std::shared_ptr<RemoteProjectStore> RemoteStore = CreateFileRemoteStore(Log(), Options);
- SaveOplog(CidStore,
+ SaveOplog(Log(),
+ CidStore,
*RemoteStore,
Project,
*Oplog,
@@ -4856,7 +4916,8 @@ SetupPartialBlockExportStore(WorkerThreadPool& NetworkPool, WorkerThreadPool& Wo
/*.ForceDisableBlocks =*/false,
/*.ForceEnableTempBlocks =*/false};
std::shared_ptr<RemoteProjectStore> RemoteStore = CreateFileRemoteStore(Log(), Options);
- SaveOplog(LocalCidStore,
+ SaveOplog(Log(),
+ LocalCidStore,
*RemoteStore,
*LocalProject,
*Oplog,
@@ -5024,7 +5085,8 @@ TEST_CASE("project.store.import.context_settings")
bool PopulateCache,
bool ForceDownload) -> void {
Ref<ProjectStore::Oplog> ImportOplog = ImportProject->NewOplog(fmt::format("import_{}", OpJobIndex++), {});
- LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = ImportCidStore,
.RemoteStore = *RemoteStore,
.OptionalCache = OptCache,
.CacheBuildId = CacheBuildId,
@@ -5131,7 +5193,8 @@ TEST_CASE("project.store.import.context_settings")
// StoreMaxRangeCountPerRequest=128 -> all three ranges sent in one LoadAttachmentRanges call.
Ref<ProjectStore::Oplog> PartialOplog = ImportProject->NewOplog(fmt::format("partial_cloud_multi_{}", OpJobIndex++), {});
- LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = ImportCidStore,
.RemoteStore = *PartialRemoteStore,
.OptionalCache = nullptr,
.CacheBuildId = CacheBuildId,
@@ -5163,7 +5226,8 @@ TEST_CASE("project.store.import.context_settings")
SeedCidStoreWithAlternateChunks(ImportCidStore, *PartialRemoteStore, BlockHash);
Ref<ProjectStore::Oplog> PartialOplog = ImportProject->NewOplog(fmt::format("partial_cloud_single_{}", OpJobIndex++), {});
- LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = ImportCidStore,
.RemoteStore = *PartialRemoteStore,
.OptionalCache = nullptr,
.CacheBuildId = CacheBuildId,
@@ -5194,7 +5258,8 @@ TEST_CASE("project.store.import.context_settings")
// Phase 1: full block download from remote populates the cache.
{
Ref<ProjectStore::Oplog> Phase1Oplog = ImportProject->NewOplog(fmt::format("partial_cache_multi_p1_{}", OpJobIndex++), {});
- LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = ImportCidStore,
.RemoteStore = *PartialRemoteStore,
.OptionalCache = Cache.get(),
.CacheBuildId = CacheBuildId,
@@ -5226,7 +5291,8 @@ TEST_CASE("project.store.import.context_settings")
SeedCidStoreWithAlternateChunks(Phase2CidStore, *PartialRemoteStore, BlockHash);
Ref<ProjectStore::Oplog> Phase2Oplog = ImportProject->NewOplog(fmt::format("partial_cache_multi_p2_{}", OpJobIndex++), {});
- LoadOplog(LoadOplogContext{.ChunkStore = Phase2CidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = Phase2CidStore,
.RemoteStore = *PartialRemoteStore,
.OptionalCache = Cache.get(),
.CacheBuildId = CacheBuildId,
@@ -5259,7 +5325,8 @@ TEST_CASE("project.store.import.context_settings")
// Phase 1: full block download from remote into cache.
{
Ref<ProjectStore::Oplog> Phase1Oplog = ImportProject->NewOplog(fmt::format("partial_cache_single_p1_{}", OpJobIndex++), {});
- LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = ImportCidStore,
.RemoteStore = *PartialRemoteStore,
.OptionalCache = Cache.get(),
.CacheBuildId = CacheBuildId,
@@ -5291,7 +5358,8 @@ TEST_CASE("project.store.import.context_settings")
SeedCidStoreWithAlternateChunks(Phase2CidStore, *PartialRemoteStore, BlockHash);
Ref<ProjectStore::Oplog> Phase2Oplog = ImportProject->NewOplog(fmt::format("partial_cache_single_p2_{}", OpJobIndex++), {});
- LoadOplog(LoadOplogContext{.ChunkStore = Phase2CidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = Phase2CidStore,
.RemoteStore = *PartialRemoteStore,
.OptionalCache = Cache.get(),
.CacheBuildId = CacheBuildId,
@@ -5373,7 +5441,8 @@ RunSaveOplog(CidStore& CidStore,
{
*OutRemoteStore = RemoteStore;
}
- SaveOplog(CidStore,
+ SaveOplog(Log(),
+ CidStore,
*RemoteStore,
Project,
Oplog,
@@ -5476,7 +5545,8 @@ TEST_CASE("project.store.embed_loose_files_true")
/*ForceDisableBlocks=*/false,
&RemoteStore);
Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_embed_true_import", {});
- LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -5530,7 +5600,8 @@ TEST_CASE("project.store.embed_loose_files_false" * doctest::skip()) // superse
&RemoteStore);
Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_embed_false_import", {});
- LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -5693,7 +5764,8 @@ TEST_CASE("project.store.export.large_file_attachment_direct")
&RemoteStore);
Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_large_direct_import", {});
- LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -5750,7 +5822,8 @@ TEST_CASE("project.store.export.large_file_attachment_via_temp")
&RemoteStore);
Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_large_via_temp_import", {});
- LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -5804,7 +5877,8 @@ TEST_CASE("project.store.export.large_chunk_from_cidstore")
&RemoteStore);
Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_large_cid_import", {});
- LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -5867,7 +5941,8 @@ TEST_CASE("project.store.export.block_reuse")
BlockHashesAfterFirst.push_back(B.BlockHash);
}
- SaveOplog(CidStore,
+ SaveOplog(Log(),
+ CidStore,
*RemoteStore,
*Project,
*Oplog,
@@ -5944,7 +6019,8 @@ TEST_CASE("project.store.export.max_chunks_per_block")
CHECK(KnownBlocks.Blocks.size() >= 2);
Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_max_chunks_import", {});
- LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -6027,7 +6103,8 @@ TEST_CASE("project.store.export.max_data_per_block")
CHECK(KnownBlocks.Blocks.size() >= 2);
Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_max_data_per_block_import", {});
- LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -6155,7 +6232,8 @@ TEST_CASE("project.store.embed_loose_files_zero_data_hash")
&RemoteStore);
Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_zero_data_hash_import", {});
- LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -6209,7 +6287,8 @@ TEST_CASE("project.store.embed_loose_files_already_resolved")
&RemoteStore1);
Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_already_resolved_import", {});
- LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore1,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -6296,7 +6375,8 @@ TEST_CASE("project.store.import.missing_attachment")
Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_missing_att_throw", {});
REQUIRE(ImportOplog);
CapturingJobContext Ctx;
- CHECK_THROWS_AS(LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ CHECK_THROWS_AS(LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -6313,7 +6393,8 @@ TEST_CASE("project.store.import.missing_attachment")
Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_missing_att_ignore", {});
REQUIRE(ImportOplog);
CapturingJobContext Ctx;
- CHECK_NOTHROW(LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ CHECK_NOTHROW(LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -6358,7 +6439,8 @@ TEST_CASE("project.store.import.error.load_container_failure")
REQUIRE(ImportOplog);
CapturingJobContext Ctx;
- CHECK_THROWS_AS(LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ CHECK_THROWS_AS(LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -6785,6 +6867,7 @@ TEST_CASE("buildcontainer.public_overload_smoke")
std::atomic<int> BlockCallCount{0};
CbObject Container = BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -6828,6 +6911,7 @@ TEST_CASE("buildcontainer.build_blocks_false_on_block_chunks")
std::atomic<int> BlockChunksCallCount{0};
CbObject Container = BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -6893,6 +6977,7 @@ TEST_CASE("buildcontainer.ignore_missing_binary_attachment_warn")
{
CapturingJobContext Ctx;
BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -6916,6 +7001,7 @@ TEST_CASE("buildcontainer.ignore_missing_binary_attachment_warn")
SUBCASE("throw")
{
CHECK_THROWS(BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -6967,6 +7053,7 @@ TEST_CASE("buildcontainer.ignore_missing_file_attachment_warn")
{
CapturingJobContext Ctx;
BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -6990,6 +7077,7 @@ TEST_CASE("buildcontainer.ignore_missing_file_attachment_warn")
SUBCASE("throw")
{
CHECK_THROWS(BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -7008,6 +7096,61 @@ TEST_CASE("buildcontainer.ignore_missing_file_attachment_warn")
}
}
+TEST_CASE("buildcontainer.zero_byte_file_attachment")
+{
+ // A zero-byte file on disk is a valid attachment. BuildContainer must process
+ // it without hitting ZEN_ASSERT(UploadAttachment->Size != 0) in
+ // ResolveAttachments. The empty file flows through the compress-inline path
+ // and becomes a LooseUploadAttachment with raw size 0.
+ using namespace projectstore_testutils;
+ using namespace std::literals;
+
+ ScopedTemporaryDirectory TempDir;
+
+ GcManager Gc;
+ CidStore CidStore(Gc);
+ std::unique_ptr<ProjectStore> ProjectStoreDummy;
+ Ref<ProjectStore::Project> Project = MakeTestProject(CidStore, Gc, TempDir.Path(), ProjectStoreDummy);
+
+ std::filesystem::path RootDir = TempDir.Path() / "root";
+ auto FileAtts = CreateFileAttachments(RootDir, std::initializer_list<size_t>{512});
+
+ Ref<ProjectStore::Oplog> Oplog = Project->NewOplog("bc_zero_byte_file", {});
+ REQUIRE(Oplog);
+ Oplog->AppendNewOplogEntry(CreateFilesOplogPackage(Oid::NewOid(), RootDir, FileAtts));
+
+ // Truncate the file to zero bytes after the oplog entry is created.
+ // The file still exists on disk so RewriteOplog's IsFile() check passes,
+ // but MakeFromFile returns a zero-size buffer.
+ std::filesystem::resize_file(FileAtts[0].second, 0);
+
+ WorkerThreadPool WorkerPool(GetWorkerCount());
+
+ CbObject Container = BuildContainer(
+ Log(),
+ CidStore,
+ *Project,
+ *Oplog,
+ WorkerPool,
+ 64u * 1024u,
+ 1000,
+ 32u * 1024u,
+ 64u * 1024u * 1024u,
+ /*BuildBlocks=*/true,
+ /*IgnoreMissingAttachments=*/false,
+ /*AllowChunking=*/true,
+ [](CompressedBuffer&&, ChunkBlockDescription&&) {},
+ [](const IoHash&, TGetAttachmentBufferFunc&&) {},
+ [](std::vector<std::pair<IoHash, FetchChunkFunc>>&&) {},
+ /*EmbedLooseFiles=*/true);
+
+ CHECK(Container.GetSize() > 0);
+
+ // The zero-byte attachment is packed into a block via the compress-inline path.
+ CbArrayView Blocks = Container["blocks"sv].AsArrayView();
+ CHECK(Blocks.Num() > 0);
+}
+
TEST_CASE("buildcontainer.embed_loose_files_false_no_rewrite")
{
// EmbedLooseFiles=false: RewriteOp is skipped for file-op entries; they pass through
@@ -7030,6 +7173,7 @@ TEST_CASE("buildcontainer.embed_loose_files_false_no_rewrite")
WorkerThreadPool WorkerPool(GetWorkerCount());
CbObject Container = BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -7080,6 +7224,7 @@ TEST_CASE("buildcontainer.allow_chunking_false")
{
std::atomic<int> LargeAttachmentCallCount{0};
BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -7103,6 +7248,7 @@ TEST_CASE("buildcontainer.allow_chunking_false")
// Chunking branch in FindChunkSizes is taken, but the ~4 KB chunk still exceeds MaxChunkEmbedSize -> OnLargeAttachment.
std::atomic<int> LargeAttachmentCallCount{0};
BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -7144,6 +7290,7 @@ TEST_CASE("buildcontainer.async_on_block_exception_propagates")
WorkerThreadPool WorkerPool(GetWorkerCount());
CHECK_THROWS_AS(BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -7184,6 +7331,7 @@ TEST_CASE("buildcontainer.on_large_attachment_exception_propagates")
WorkerThreadPool WorkerPool(GetWorkerCount());
CHECK_THROWS_AS(BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -7226,6 +7374,7 @@ TEST_CASE("buildcontainer.context_cancellation_aborts")
Ctx.m_Cancel = true;
CHECK_NOTHROW(BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -7265,6 +7414,7 @@ TEST_CASE("buildcontainer.context_progress_reporting")
CapturingJobContext Ctx;
BuildContainer(
+ Log(),
CidStore,
*Project,
*Oplog,
@@ -7428,7 +7578,8 @@ TEST_CASE("loadoplog.missing_block_attachment_ignored")
CapturingJobContext Ctx;
Ref<ProjectStore::Oplog> ImportOplog = Project->NewOplog("oplog_missing_block_import", {});
- CHECK_NOTHROW(LoadOplog(LoadOplogContext{.ChunkStore = CidStore,
+ CHECK_NOTHROW(LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = CidStore,
.RemoteStore = *RemoteStore,
.Oplog = *ImportOplog,
.NetworkWorkerPool = NetworkPool,
@@ -7501,7 +7652,8 @@ TEST_CASE("loadoplog.clean_oplog_with_populated_cache")
{
Ref<ProjectStore::Oplog> Phase1Oplog = ImportProject->NewOplog("oplog_clean_cache_p1", {});
- LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore,
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = ImportCidStore,
.RemoteStore = *RemoteStore,
.OptionalCache = Cache.get(),
.CacheBuildId = CacheBuildId,
@@ -7517,7 +7669,8 @@ TEST_CASE("loadoplog.clean_oplog_with_populated_cache")
{
Ref<ProjectStore::Oplog> Phase2Oplog = ImportProject->NewOplog("oplog_clean_cache_p2", {});
- CHECK_NOTHROW(LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore,
+ CHECK_NOTHROW(LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = ImportCidStore,
.RemoteStore = *RemoteStore,
.OptionalCache = Cache.get(),
.CacheBuildId = CacheBuildId,
@@ -7532,6 +7685,158 @@ TEST_CASE("loadoplog.clean_oplog_with_populated_cache")
}
}
+TEST_CASE("project.store.export.block_reuse_fresh_receiver")
+{
+ // Regression test: after a second export that reuses existing blocks, a fresh import must still
+ // receive all chunks. The bug: FindReuseBlocks erases reused-block chunks from UploadAttachments,
+ // but never adds the reused blocks to the container's "blocks" section. A fresh receiver then
+ // silently misses those chunks because ParseOplogContainer never sees them.
+ using namespace projectstore_testutils;
+ using namespace std::literals;
+
+ ScopedTemporaryDirectory TempDir;
+ ScopedTemporaryDirectory ExportDir;
+
+ // -- Export side ----------------------------------------------------------
+ GcManager ExportGc;
+ CidStore ExportCidStore(ExportGc);
+ CidStoreConfiguration ExportCidConfig = {.RootDirectory = TempDir.Path() / "export_cas",
+ .TinyValueThreshold = 1024,
+ .HugeValueThreshold = 4096};
+ ExportCidStore.Initialize(ExportCidConfig);
+
+ std::filesystem::path ExportBasePath = TempDir.Path() / "export_projectstore";
+ ProjectStore ExportProjectStore(ExportCidStore, ExportBasePath, ExportGc, ProjectStore::Configuration{});
+ std::filesystem::path RootDir = TempDir.Path() / "root";
+ std::filesystem::path EngineRootDir = TempDir.Path() / "engine";
+ std::filesystem::path ProjectRootDir = TempDir.Path() / "game";
+ std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject";
+ Ref<ProjectStore::Project> ExportProject(ExportProjectStore.NewProject(ExportBasePath / "proj1"sv,
+ "proj1"sv,
+ RootDir.string(),
+ EngineRootDir.string(),
+ ProjectRootDir.string(),
+ ProjectFilePath.string()));
+
+ // 20 KB with None encoding: compressed ~ 20 KB < MaxChunkEmbedSize (32 KB) -> packed into blocks.
+ Ref<ProjectStore::Oplog> Oplog = ExportProject->NewOplog("oplog_reuse_rt", {});
+ REQUIRE(Oplog);
+ Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage(
+ Oid::NewOid(),
+ CreateAttachments(std::initializer_list<size_t>{20u * 1024u, 20u * 1024u}, OodleCompressionLevel::None)));
+
+ TestWorkerPools Pools;
+ WorkerThreadPool& NetworkPool = Pools.NetworkPool;
+ WorkerThreadPool& WorkerPool = Pools.WorkerPool;
+
+ constexpr size_t MaxBlockSize = 64u * 1024u;
+ constexpr size_t MaxChunksPerBlock = 1000;
+ constexpr size_t MaxChunkEmbedSize = 32u * 1024u;
+ constexpr size_t ChunkFileSizeLimit = 64u * 1024u * 1024u;
+
+ // First export: creates blocks on disk.
+ FileRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize,
+ .MaxChunksPerBlock = MaxChunksPerBlock,
+ .MaxChunkEmbedSize = MaxChunkEmbedSize,
+ .ChunkFileSizeLimit = ChunkFileSizeLimit},
+ /*.FolderPath =*/ExportDir.Path(),
+ /*.Name =*/std::string("oplog_reuse_rt"),
+ /*.OptionalBaseName =*/std::string(),
+ /*.ForceDisableBlocks =*/false,
+ /*.ForceEnableTempBlocks =*/false};
+
+ std::shared_ptr<RemoteProjectStore> RemoteStore = CreateFileRemoteStore(Log(), Options);
+ SaveOplog(Log(),
+ ExportCidStore,
+ *RemoteStore,
+ *ExportProject,
+ *Oplog,
+ NetworkPool,
+ WorkerPool,
+ MaxBlockSize,
+ MaxChunksPerBlock,
+ MaxChunkEmbedSize,
+ ChunkFileSizeLimit,
+ /*EmbedLooseFiles*/ true,
+ /*ForceUpload*/ false,
+ /*IgnoreMissingAttachments*/ false,
+ /*OptionalContext*/ nullptr);
+
+ // Verify first export produced blocks.
+ RemoteProjectStore::GetKnownBlocksResult KnownAfterFirst = RemoteStore->GetKnownBlocks();
+ REQUIRE(!KnownAfterFirst.Blocks.empty());
+
+ // Second export to the SAME store: triggers block reuse via GetKnownBlocks.
+ SaveOplog(Log(),
+ ExportCidStore,
+ *RemoteStore,
+ *ExportProject,
+ *Oplog,
+ NetworkPool,
+ WorkerPool,
+ MaxBlockSize,
+ MaxChunksPerBlock,
+ MaxChunkEmbedSize,
+ ChunkFileSizeLimit,
+ /*EmbedLooseFiles*/ true,
+ /*ForceUpload*/ false,
+ /*IgnoreMissingAttachments*/ false,
+ /*OptionalContext*/ nullptr);
+
+ // Verify the container has no duplicate block entries.
+ {
+ RemoteProjectStore::LoadContainerResult ContainerResult = RemoteStore->LoadContainer();
+ REQUIRE(ContainerResult.ErrorCode == 0);
+ std::vector<IoHash> BlockHashes = GetBlockHashesFromOplog(ContainerResult.ContainerObject);
+ REQUIRE(!BlockHashes.empty());
+ std::unordered_set<IoHash, IoHash::Hasher> UniqueBlockHashes(BlockHashes.begin(), BlockHashes.end());
+ CHECK(UniqueBlockHashes.size() == BlockHashes.size());
+ }
+
+ // Collect all attachment hashes referenced by the oplog ops.
+ std::unordered_set<IoHash, IoHash::Hasher> ExpectedHashes;
+ Oplog->IterateOplogWithKey([&](int, const Oid&, CbObjectView Op) {
+ Op.IterateAttachments([&](CbFieldView FieldView) { ExpectedHashes.insert(FieldView.AsAttachment()); });
+ });
+ REQUIRE(!ExpectedHashes.empty());
+
+ // -- Import side (fresh, empty CAS) --------------------------------------
+ GcManager ImportGc;
+ CidStore ImportCidStore(ImportGc);
+ CidStoreConfiguration ImportCidConfig = {.RootDirectory = TempDir.Path() / "import_cas",
+ .TinyValueThreshold = 1024,
+ .HugeValueThreshold = 4096};
+ ImportCidStore.Initialize(ImportCidConfig);
+
+ std::filesystem::path ImportBasePath = TempDir.Path() / "import_projectstore";
+ ProjectStore ImportProjectStore(ImportCidStore, ImportBasePath, ImportGc, ProjectStore::Configuration{});
+ Ref<ProjectStore::Project> ImportProject(ImportProjectStore.NewProject(ImportBasePath / "proj1"sv,
+ "proj1"sv,
+ RootDir.string(),
+ EngineRootDir.string(),
+ ProjectRootDir.string(),
+ ProjectFilePath.string()));
+
+ Ref<ProjectStore::Oplog> ImportOplog = ImportProject->NewOplog("oplog_reuse_rt_import", {});
+ REQUIRE(ImportOplog);
+
+ LoadOplog(LoadOplogContext{.Log = Log(),
+ .ChunkStore = ImportCidStore,
+ .RemoteStore = *RemoteStore,
+ .Oplog = *ImportOplog,
+ .NetworkWorkerPool = NetworkPool,
+ .WorkerPool = WorkerPool,
+ .ForceDownload = true,
+ .IgnoreMissingAttachments = false,
+ .PartialBlockRequestMode = EPartialBlockRequestMode::All});
+
+ // Every attachment hash from the original oplog must be present in the import CAS.
+ for (const IoHash& Hash : ExpectedHashes)
+ {
+ CHECK_MESSAGE(ImportCidStore.ContainsChunk(Hash), "Missing chunk after import: ", Hash);
+ }
+}
+
TEST_SUITE_END();
#endif // ZEN_WITH_TESTS
diff --git a/src/zenremotestore/zenremotestore.cpp b/src/zenremotestore/zenremotestore.cpp
index 0b205b296..9642f8470 100644
--- a/src/zenremotestore/zenremotestore.cpp
+++ b/src/zenremotestore/zenremotestore.cpp
@@ -9,7 +9,6 @@
#include <zenremotestore/chunking/chunkedcontent.h>
#include <zenremotestore/chunking/chunkedfile.h>
#include <zenremotestore/chunking/chunkingcache.h>
-#include <zenremotestore/filesystemutils.h>
#include <zenremotestore/projectstore/remoteprojectstore.h>
#if ZEN_WITH_TESTS
@@ -27,7 +26,6 @@ zenremotestore_forcelinktests()
chunkedcontent_forcelink();
chunkedfile_forcelink();
chunkingcache_forcelink();
- filesystemutils_forcelink();
remoteprojectstore_forcelink();
}
diff --git a/src/zens3-testbed/main.cpp b/src/zens3-testbed/main.cpp
deleted file mode 100644
index 4cd6b411f..000000000
--- a/src/zens3-testbed/main.cpp
+++ /dev/null
@@ -1,526 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-// Simple test bed for exercising the zens3 module against a real S3 bucket.
-//
-// Usage:
-// zens3-testbed --bucket <name> --region <region> [command] [args...]
-//
-// Credentials are read from environment variables:
-// AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
-//
-// Commands:
-// put <key> <file> Upload a local file
-// get <key> [file] Download an object (prints to stdout if no file given)
-// head <key> Check if object exists, show metadata
-// delete <key> Delete an object
-// list [prefix] List objects with optional prefix
-// multipart-put <key> <file> [part-size-mb] Upload via multipart
-// roundtrip <key> Upload test data, download, verify, delete
-
-#include <zenutil/cloud/imdscredentials.h>
-#include <zenutil/cloud/s3client.h>
-
-#include <zencore/except_fmt.h>
-#include <zencore/filesystem.h>
-#include <zencore/iobuffer.h>
-#include <zencore/logging.h>
-#include <zencore/string.h>
-
-#include <zencore/memory/newdelete.h>
-
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <fmt/format.h>
-#include <cxxopts.hpp>
-ZEN_THIRD_PARTY_INCLUDES_END
-
-#include <cstdlib>
-#include <fstream>
-#include <iostream>
-
-namespace {
-
-using namespace zen;
-
-std::string
-GetEnvVar(const char* Name)
-{
- const char* Value = std::getenv(Name);
- return Value ? std::string(Value) : std::string();
-}
-
-IoBuffer
-ReadFileToBuffer(const std::filesystem::path& Path)
-{
- return zen::ReadFile(Path).Flatten();
-}
-
-void
-WriteBufferToFile(const IoBuffer& Buffer, const std::filesystem::path& Path)
-{
- std::ofstream File(Path, std::ios::binary);
- if (!File)
- {
- throw zen::runtime_error("failed to open '{}' for writing", Path.string());
- }
- File.write(reinterpret_cast<const char*>(Buffer.GetData()), static_cast<std::streamsize>(Buffer.GetSize()));
-}
-
-S3Client
-CreateClient(const cxxopts::ParseResult& Args)
-{
- S3ClientOptions Options;
- Options.BucketName = Args["bucket"].as<std::string>();
- Options.Region = Args["region"].as<std::string>();
-
- if (Args.count("imds"))
- {
- // Use IMDS credential provider for EC2 instances
- ImdsCredentialProviderOptions ImdsOpts;
- if (Args.count("imds-endpoint"))
- {
- ImdsOpts.Endpoint = Args["imds-endpoint"].as<std::string>();
- }
- Options.CredentialProvider = Ref<ImdsCredentialProvider>(new ImdsCredentialProvider(ImdsOpts));
- }
- else
- {
- std::string AccessKey = GetEnvVar("AWS_ACCESS_KEY_ID");
- std::string SecretKey = GetEnvVar("AWS_SECRET_ACCESS_KEY");
- std::string SessionToken = GetEnvVar("AWS_SESSION_TOKEN");
-
- if (AccessKey.empty() || SecretKey.empty())
- {
- throw zen::runtime_error("AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables must be set");
- }
-
- Options.Credentials.AccessKeyId = std::move(AccessKey);
- Options.Credentials.SecretAccessKey = std::move(SecretKey);
- Options.Credentials.SessionToken = std::move(SessionToken);
- }
-
- if (Args.count("endpoint"))
- {
- Options.Endpoint = Args["endpoint"].as<std::string>();
- }
-
- if (Args.count("path-style"))
- {
- Options.PathStyle = true;
- }
-
- if (Args.count("timeout"))
- {
- Options.Timeout = std::chrono::milliseconds(Args["timeout"].as<int>() * 1000);
- }
-
- return S3Client(Options);
-}
-
-int
-CmdPut(S3Client& Client, const std::vector<std::string>& Positional)
-{
- if (Positional.size() < 3)
- {
- fmt::print(stderr, "Usage: zens3-testbed ... put <key> <file>\n");
- return 1;
- }
-
- const auto& Key = Positional[1];
- const auto& FilePath = Positional[2];
-
- IoBuffer Content = ReadFileToBuffer(FilePath);
- fmt::print("Uploading '{}' ({} bytes) to s3://{}/{}\n", FilePath, Content.GetSize(), Client.BucketName(), Key);
-
- S3Result Result = Client.PutObject(Key, Content);
- if (!Result)
- {
- fmt::print(stderr, "PUT failed: {}\n", Result.Error);
- return 1;
- }
-
- fmt::print("OK\n");
- return 0;
-}
-
-int
-CmdGet(S3Client& Client, const std::vector<std::string>& Positional)
-{
- if (Positional.size() < 2)
- {
- fmt::print(stderr, "Usage: zens3-testbed ... get <key> [file]\n");
- return 1;
- }
-
- const auto& Key = Positional[1];
-
- S3GetObjectResult Result = Client.GetObject(Key);
- if (!Result)
- {
- fmt::print(stderr, "GET failed: {}\n", Result.Error);
- return 1;
- }
-
- if (Positional.size() >= 3)
- {
- const auto& FilePath = Positional[2];
- WriteBufferToFile(Result.Content, FilePath);
- fmt::print("Downloaded {} bytes to '{}'\n", Result.Content.GetSize(), FilePath);
- }
- else
- {
- // Print to stdout
- std::string_view Text = Result.AsText();
- std::cout.write(Text.data(), static_cast<std::streamsize>(Text.size()));
- std::cout << std::endl;
- }
-
- return 0;
-}
-
-int
-CmdHead(S3Client& Client, const std::vector<std::string>& Positional)
-{
- if (Positional.size() < 2)
- {
- fmt::print(stderr, "Usage: zens3-testbed ... head <key>\n");
- return 1;
- }
-
- const auto& Key = Positional[1];
-
- S3HeadObjectResult Result = Client.HeadObject(Key);
-
- if (!Result)
- {
- fmt::print(stderr, "HEAD failed: {}\n", Result.Error);
- return 1;
- }
-
- if (Result.Status == HeadObjectResult::NotFound)
- {
- fmt::print("Object '{}' does not exist\n", Key);
- return 1;
- }
-
- fmt::print("Key: {}\n", Result.Info.Key);
- fmt::print("Size: {} bytes\n", Result.Info.Size);
- fmt::print("ETag: {}\n", Result.Info.ETag);
- fmt::print("Last-Modified: {}\n", Result.Info.LastModified);
- return 0;
-}
-
-int
-CmdDelete(S3Client& Client, const std::vector<std::string>& Positional)
-{
- if (Positional.size() < 2)
- {
- fmt::print(stderr, "Usage: zens3-testbed ... delete <key>\n");
- return 1;
- }
-
- const auto& Key = Positional[1];
-
- S3Result Result = Client.DeleteObject(Key);
- if (!Result)
- {
- fmt::print(stderr, "DELETE failed: {}\n", Result.Error);
- return 1;
- }
-
- fmt::print("Deleted '{}'\n", Key);
- return 0;
-}
-
-int
-CmdList(S3Client& Client, const std::vector<std::string>& Positional)
-{
- std::string Prefix;
- if (Positional.size() >= 2)
- {
- Prefix = Positional[1];
- }
-
- S3ListObjectsResult Result = Client.ListObjects(Prefix);
- if (!Result)
- {
- fmt::print(stderr, "LIST failed: {}\n", Result.Error);
- return 1;
- }
-
- fmt::print("{} objects found:\n", Result.Objects.size());
- for (const auto& Obj : Result.Objects)
- {
- fmt::print(" {:>12} {} {}\n", Obj.Size, Obj.LastModified, Obj.Key);
- }
-
- return 0;
-}
-
-int
-CmdMultipartPut(S3Client& Client, const std::vector<std::string>& Positional)
-{
- if (Positional.size() < 3)
- {
- fmt::print(stderr, "Usage: zens3-testbed ... multipart-put <key> <file> [part-size-mb]\n");
- return 1;
- }
-
- const auto& Key = Positional[1];
- const auto& FilePath = Positional[2];
-
- uint64_t PartSize = 8 * 1024 * 1024; // 8 MB default
- if (Positional.size() >= 4)
- {
- PartSize = std::stoull(Positional[3]) * 1024 * 1024;
- }
-
- IoBuffer Content = ReadFileToBuffer(FilePath);
- fmt::print("Multipart uploading '{}' ({} bytes, part size {} MB) to s3://{}/{}\n",
- FilePath,
- Content.GetSize(),
- PartSize / (1024 * 1024),
- Client.BucketName(),
- Key);
-
- S3Result Result = Client.PutObjectMultipart(Key, Content, PartSize);
- if (!Result)
- {
- fmt::print(stderr, "Multipart PUT failed: {}\n", Result.Error);
- return 1;
- }
-
- fmt::print("OK\n");
- return 0;
-}
-
-int
-CmdRoundtrip(S3Client& Client, const std::vector<std::string>& Positional)
-{
- if (Positional.size() < 2)
- {
- fmt::print(stderr, "Usage: zens3-testbed ... roundtrip <key>\n");
- return 1;
- }
-
- const auto& Key = Positional[1];
-
- // Generate test data
- const size_t TestSize = 1024 * 64; // 64 KB
- std::vector<uint8_t> TestData(TestSize);
- for (size_t i = 0; i < TestSize; ++i)
- {
- TestData[i] = static_cast<uint8_t>(i & 0xFF);
- }
-
- IoBuffer UploadContent(IoBuffer::Clone, TestData.data(), TestData.size());
-
- fmt::print("=== Roundtrip test for key '{}' ===\n\n", Key);
-
- // PUT
- fmt::print("[1/4] PUT {} bytes...\n", TestSize);
- S3Result Result = Client.PutObject(Key, UploadContent);
- if (!Result)
- {
- fmt::print(stderr, " FAILED: {}\n", Result.Error);
- return 1;
- }
- fmt::print(" OK\n");
-
- // HEAD
- fmt::print("[2/4] HEAD...\n");
- S3HeadObjectResult HeadResult = Client.HeadObject(Key);
- if (HeadResult.Status != HeadObjectResult::Found)
- {
- fmt::print(stderr, " FAILED: {}\n", !HeadResult ? HeadResult.Error : "not found");
- return 1;
- }
- fmt::print(" OK (size={}, etag={})\n", HeadResult.Info.Size, HeadResult.Info.ETag);
-
- if (HeadResult.Info.Size != TestSize)
- {
- fmt::print(stderr, " SIZE MISMATCH: expected {}, got {}\n", TestSize, HeadResult.Info.Size);
- return 1;
- }
-
- // GET
- fmt::print("[3/4] GET and verify...\n");
- S3GetObjectResult GetResult = Client.GetObject(Key);
- if (!GetResult)
- {
- fmt::print(stderr, " FAILED: {}\n", GetResult.Error);
- return 1;
- }
-
- if (GetResult.Content.GetSize() != TestSize)
- {
- fmt::print(stderr, " SIZE MISMATCH: expected {}, got {}\n", TestSize, GetResult.Content.GetSize());
- return 1;
- }
-
- if (memcmp(GetResult.Content.GetData(), TestData.data(), TestSize) != 0)
- {
- fmt::print(stderr, " DATA MISMATCH\n");
- return 1;
- }
- fmt::print(" OK (verified {} bytes)\n", TestSize);
-
- // DELETE
- fmt::print("[4/4] DELETE...\n");
- Result = Client.DeleteObject(Key);
- if (!Result)
- {
- fmt::print(stderr, " FAILED: {}\n", Result.Error);
- return 1;
- }
- fmt::print(" OK\n");
-
- fmt::print("\n=== Roundtrip test PASSED ===\n");
- return 0;
-}
-
-int
-CmdPresign(S3Client& Client, const std::vector<std::string>& Positional)
-{
- if (Positional.size() < 2)
- {
- fmt::print(stderr, "Usage: zens3-testbed ... presign <key> [method] [expires-seconds]\n");
- return 1;
- }
-
- const auto& Key = Positional[1];
-
- std::string Method = "GET";
- if (Positional.size() >= 3)
- {
- Method = Positional[2];
- }
-
- std::chrono::seconds ExpiresIn(3600);
- if (Positional.size() >= 4)
- {
- ExpiresIn = std::chrono::seconds(std::stoul(Positional[3]));
- }
-
- std::string Url;
- if (Method == "PUT")
- {
- Url = Client.GeneratePresignedPutUrl(Key, ExpiresIn);
- }
- else
- {
- Url = Client.GeneratePresignedGetUrl(Key, ExpiresIn);
- }
-
- fmt::print("{}\n", Url);
- return 0;
-}
-
-} // namespace
-
-int
-main(int argc, char* argv[])
-{
- using namespace zen;
-
- logging::InitializeLogging();
-
- cxxopts::Options Options("zens3-testbed", "Test bed for exercising S3 operations via the zens3 module");
-
- // clang-format off
- Options.add_options()
- ("b,bucket", "S3 bucket name", cxxopts::value<std::string>())
- ("r,region", "AWS region", cxxopts::value<std::string>()->default_value("us-east-1"))
- ("e,endpoint", "Custom S3 endpoint URL", cxxopts::value<std::string>())
- ("path-style", "Use path-style addressing (for MinIO, etc.)")
- ("imds", "Use EC2 IMDS for credentials instead of env vars")
- ("imds-endpoint", "Custom IMDS endpoint URL (for testing)", cxxopts::value<std::string>())
- ("timeout", "Request timeout in seconds", cxxopts::value<int>()->default_value("30"))
- ("v,verbose", "Enable verbose logging")
- ("h,help", "Show help")
- ("positional", "Command and arguments", cxxopts::value<std::vector<std::string>>());
- // clang-format on
-
- Options.parse_positional({"positional"});
- Options.positional_help("<command> [args...]");
-
- try
- {
- auto Result = Options.parse(argc, argv);
-
- if (Result.count("help") || !Result.count("positional"))
- {
- fmt::print("{}\n", Options.help());
- fmt::print("Commands:\n");
- fmt::print(" put <key> <file> Upload a local file\n");
- fmt::print(" get <key> [file] Download (to file or stdout)\n");
- fmt::print(" head <key> Show object metadata\n");
- fmt::print(" delete <key> Delete an object\n");
- fmt::print(" list [prefix] List objects\n");
- fmt::print(" multipart-put <key> <file> [part-mb] Multipart upload\n");
- fmt::print(" roundtrip <key> Upload/download/verify/delete\n");
- fmt::print(" presign <key> [method] [expires-sec] Generate pre-signed URL\n");
- fmt::print("\nCredentials via AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY env vars,\n");
- fmt::print("or use --imds to fetch from EC2 Instance Metadata Service.\n");
- return 0;
- }
-
- if (!Result.count("bucket"))
- {
- fmt::print(stderr, "Error: --bucket is required\n");
- return 1;
- }
-
- if (Result.count("verbose"))
- {
- logging::SetLogLevel(logging::Debug);
- }
-
- auto Client = CreateClient(Result);
-
- const auto& Positional = Result["positional"].as<std::vector<std::string>>();
- const auto& Command = Positional[0];
-
- if (Command == "put")
- {
- return CmdPut(Client, Positional);
- }
- else if (Command == "get")
- {
- return CmdGet(Client, Positional);
- }
- else if (Command == "head")
- {
- return CmdHead(Client, Positional);
- }
- else if (Command == "delete")
- {
- return CmdDelete(Client, Positional);
- }
- else if (Command == "list")
- {
- return CmdList(Client, Positional);
- }
- else if (Command == "multipart-put")
- {
- return CmdMultipartPut(Client, Positional);
- }
- else if (Command == "roundtrip")
- {
- return CmdRoundtrip(Client, Positional);
- }
- else if (Command == "presign")
- {
- return CmdPresign(Client, Positional);
- }
- else
- {
- fmt::print(stderr, "Unknown command: '{}'\n", Command);
- return 1;
- }
- }
- catch (const std::exception& Ex)
- {
- fmt::print(stderr, "Error: {}\n", Ex.what());
- return 1;
- }
-}
diff --git a/src/zens3-testbed/xmake.lua b/src/zens3-testbed/xmake.lua
deleted file mode 100644
index 168ab9de9..000000000
--- a/src/zens3-testbed/xmake.lua
+++ /dev/null
@@ -1,8 +0,0 @@
--- Copyright Epic Games, Inc. All Rights Reserved.
-
-target("zens3-testbed")
- set_kind("binary")
- set_group("tools")
- add_files("*.cpp")
- add_deps("zenutil", "zencore")
- add_deps("cxxopts", "fmt")
diff --git a/src/zenserver-test/buildstore-tests.cpp b/src/zenserver-test/buildstore-tests.cpp
index cf9b10896..1f2157993 100644
--- a/src/zenserver-test/buildstore-tests.cpp
+++ b/src/zenserver-test/buildstore-tests.cpp
@@ -148,8 +148,6 @@ TEST_CASE("buildstore.blobs")
}
{
- // Single-range Get
-
ZenServerInstance Instance(TestEnv);
const uint16_t PortNumber =
@@ -158,6 +156,7 @@ TEST_CASE("buildstore.blobs")
HttpClient Client(Instance.GetBaseUri() + "/builds/");
+ // Single-range Get
{
const IoHash& RawHash = CompressedBlobsHashes.front();
uint64_t BlobSize = CompressedBlobsSizes.front();
@@ -183,20 +182,63 @@ TEST_CASE("buildstore.blobs")
MemoryView RangeView = Payload.GetView();
CHECK(ActualRange.EqualBytes(RangeView));
}
- }
- {
- // Single-range Post
+ {
+ // GET blob not found
+ IoHash FakeHash = IoHash::HashBuffer("nonexistent", 11);
+ HttpClient::Response Result = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, FakeHash));
+ CHECK_EQ(Result.StatusCode, HttpResponseCode::NotFound);
+ }
- ZenServerInstance Instance(TestEnv);
+ {
+ // GET with out-of-bounds range
+ const IoHash& RawHash = CompressedBlobsHashes.front();
+ uint64_t BlobSize = CompressedBlobsSizes.front();
- const uint16_t PortNumber =
- Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath));
- CHECK(PortNumber != 0);
+ HttpClient::KeyValueMap Headers;
+ Headers.Entries.insert({"Range", fmt::format("bytes={}-{}", BlobSize + 100, BlobSize + 200)});
- HttpClient Client(Instance.GetBaseUri() + "/builds/");
+ HttpClient::Response Result = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), Headers);
+ CHECK_EQ(Result.StatusCode, HttpResponseCode::RangeNotSatisfiable);
+ }
{
+ // GET with multi-range header (uses Download for multipart boundary parsing)
+ const IoHash& RawHash = CompressedBlobsHashes.front();
+ uint64_t BlobSize = CompressedBlobsSizes.front();
+
+ uint64_t Range1Start = 0;
+ uint64_t Range1End = BlobSize / 4 - 1;
+ uint64_t Range2Start = BlobSize / 2;
+ uint64_t Range2End = BlobSize / 2 + BlobSize / 4 - 1;
+
+ HttpClient::KeyValueMap Headers;
+ Headers.Entries.insert({"Range", fmt::format("bytes={}-{},{}-{}", Range1Start, Range1End, Range2Start, Range2End)});
+
+ HttpClient::Response Result =
+ Client.Download(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), SystemRootPath, Headers);
+ CHECK_EQ(Result.StatusCode, HttpResponseCode::PartialContent);
+ REQUIRE_EQ(Result.Ranges.size(), 2);
+
+ HttpClient::Response FullBlobResult = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash),
+ HttpClient::Accept(ZenContentType::kCompressedBinary));
+ REQUIRE(FullBlobResult);
+
+ uint64_t Range1Len = Range1End - Range1Start + 1;
+ uint64_t Range2Len = Range2End - Range2Start + 1;
+
+ MemoryView ExpectedRange1 = FullBlobResult.ResponsePayload.GetView().Mid(Range1Start, Range1Len);
+ MemoryView ExpectedRange2 = FullBlobResult.ResponsePayload.GetView().Mid(Range2Start, Range2Len);
+
+ MemoryView ActualRange1 = Result.ResponsePayload.GetView().Mid(Result.Ranges[0].OffsetInPayload, Range1Len);
+ MemoryView ActualRange2 = Result.ResponsePayload.GetView().Mid(Result.Ranges[1].OffsetInPayload, Range2Len);
+
+ CHECK(ExpectedRange1.EqualBytes(ActualRange1));
+ CHECK(ExpectedRange2.EqualBytes(ActualRange2));
+ }
+
+ // Single-range Post
+ {
uint64_t RangeSizeSum = 0;
const IoHash& RawHash = CompressedBlobsHashes.front();
@@ -259,19 +301,96 @@ TEST_CASE("buildstore.blobs")
Offset += Range.second;
}
}
- }
- {
- // Multi-range
+ {
+ // POST with wrong accept type
+ const IoHash& RawHash = CompressedBlobsHashes.front();
- ZenServerInstance Instance(TestEnv);
+ CbObjectWriter Writer;
+ Writer.BeginArray("ranges"sv);
+ Writer.BeginObject();
+ Writer.AddInteger("offset"sv, uint64_t(0));
+ Writer.AddInteger("length"sv, uint64_t(10));
+ Writer.EndObject();
+ Writer.EndArray();
- const uint16_t PortNumber =
- Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath));
- CHECK(PortNumber != 0);
+ HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash),
+ Writer.Save(),
+ HttpClient::Accept(ZenContentType::kBinary));
+ CHECK_EQ(Result.StatusCode, HttpResponseCode::BadRequest);
+ }
- HttpClient Client(Instance.GetBaseUri() + "/builds/");
+ {
+ // POST with missing payload
+ const IoHash& RawHash = CompressedBlobsHashes.front();
+
+ HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash),
+ HttpClient::Accept(ZenContentType::kCbPackage));
+ CHECK_EQ(Result.StatusCode, HttpResponseCode::BadRequest);
+ }
+
+ {
+ // POST with empty ranges array
+ const IoHash& RawHash = CompressedBlobsHashes.front();
+ CbObjectWriter Writer;
+ Writer.BeginArray("ranges"sv);
+ Writer.EndArray();
+
+ HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash),
+ Writer.Save(),
+ HttpClient::Accept(ZenContentType::kCbPackage));
+ CHECK_EQ(Result.StatusCode, HttpResponseCode::BadRequest);
+ }
+
+ {
+ // POST with range count exceeding maximum
+ const IoHash& RawHash = CompressedBlobsHashes.front();
+
+ CbObjectWriter Writer;
+ Writer.BeginArray("ranges"sv);
+ for (uint32_t I = 0; I < 257; I++)
+ {
+ Writer.BeginObject();
+ Writer.AddInteger("offset"sv, uint64_t(0));
+ Writer.AddInteger("length"sv, uint64_t(1));
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+
+ HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash),
+ Writer.Save(),
+ HttpClient::Accept(ZenContentType::kCbPackage));
+ CHECK_EQ(Result.StatusCode, HttpResponseCode::BadRequest);
+ }
+
+ {
+ // POST with out-of-bounds range returns length=0
+ const IoHash& RawHash = CompressedBlobsHashes.front();
+ uint64_t BlobSize = CompressedBlobsSizes.front();
+
+ CbObjectWriter Writer;
+ Writer.BeginArray("ranges"sv);
+ Writer.BeginObject();
+ Writer.AddInteger("offset"sv, BlobSize + 100);
+ Writer.AddInteger("length"sv, uint64_t(50));
+ Writer.EndObject();
+ Writer.EndArray();
+
+ HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash),
+ Writer.Save(),
+ HttpClient::Accept(ZenContentType::kCbPackage));
+ REQUIRE(Result);
+ CbPackage ResponsePackage = ParsePackageMessage(Result.ResponsePayload);
+ CbObjectView ResponseObject = ResponsePackage.GetObject();
+ CbArrayView RangeArray = ResponseObject["ranges"sv].AsArrayView();
+ REQUIRE_EQ(RangeArray.Num(), uint64_t(1));
+ CbObjectView Range = (*begin(RangeArray)).AsObjectView();
+ CHECK_EQ(Range["offset"sv].AsUInt64(), BlobSize + 100);
+ CHECK_EQ(Range["length"sv].AsUInt64(), uint64_t(0));
+ }
+
+ // Multi-range
{
uint64_t RangeSizeSum = 0;
diff --git a/src/zenserver-test/cache-tests.cpp b/src/zenserver-test/cache-tests.cpp
index 14748e214..e54e7060d 100644
--- a/src/zenserver-test/cache-tests.cpp
+++ b/src/zenserver-test/cache-tests.cpp
@@ -9,6 +9,7 @@
# include <zencore/compactbinarypackage.h>
# include <zencore/compress.h>
# include <zencore/fmtutils.h>
+# include <zenhttp/localrefpolicy.h>
# include <zenhttp/packageformat.h>
# include <zenstore/cache/cachepolicy.h>
# include <zencore/filesystem.h>
@@ -25,6 +26,13 @@ namespace zen::tests {
TEST_SUITE_BEGIN("server.cache");
+/// Permissive policy that allows any path, for use in tests that exercise local ref
+/// functionality but are not testing path validation.
+struct PermissiveLocalRefPolicy : public ILocalRefPolicy
+{
+ void ValidatePath(const std::filesystem::path&) const override {}
+};
+
TEST_CASE("zcache.basic")
{
using namespace std::literals;
@@ -164,143 +172,85 @@ TEST_CASE("zcache.cbpackage")
return true;
};
- SUBCASE("PUT/GET returns correct package")
- {
- std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ std::filesystem::path LocalDataDir = TestEnv.CreateNewTestDir();
+ std::filesystem::path RemoteDataDir = TestEnv.CreateNewTestDir();
- ZenServerInstance Instance1(TestEnv);
- Instance1.SetDataDir(TestDir);
- const uint16_t PortNumber = Instance1.SpawnServerAndWaitUntilReady();
- const std::string BaseUri = fmt::format("http://localhost:{}/z$", PortNumber);
+ ZenServerInstance RemoteInstance(TestEnv);
+ RemoteInstance.SetDataDir(RemoteDataDir);
+ const uint16_t RemotePortNumber = RemoteInstance.SpawnServerAndWaitUntilReady();
- HttpClient Http{BaseUri};
+ ZenServerInstance LocalInstance(TestEnv);
+ LocalInstance.SetDataDir(LocalDataDir);
+ LocalInstance.SpawnServer(TestEnv.GetNewPortNumber(),
+ fmt::format("--upstream-thread-count=0 --upstream-zen-url=http://localhost:{}", RemotePortNumber));
+ const uint16_t LocalPortNumber = LocalInstance.WaitUntilReady();
+ CHECK_MESSAGE(LocalPortNumber != 0, LocalInstance.GetLogOutput());
- const std::string_view Bucket = "mosdef"sv;
- zen::IoHash Key;
- zen::CbPackage ExpectedPackage = CreateTestPackage(Key);
+ const auto LocalBaseUri = fmt::format("http://localhost:{}/z$", LocalPortNumber);
+ const auto RemoteBaseUri = fmt::format("http://localhost:{}/z$", RemotePortNumber);
- // PUT
- {
- zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage);
- HttpClient::Response Result = Http.Put(fmt::format("/{}/{}", Bucket, Key), Body);
- CHECK(Result.StatusCode == HttpResponseCode::Created);
- }
+ HttpClient LocalHttp{LocalBaseUri};
+ HttpClient RemoteHttp{RemoteBaseUri};
- // GET
- {
- HttpClient::Response Result = Http.Get(fmt::format("/{}/{}", Bucket, Key), {{"Accept", "application/x-ue-cbpkg"}});
- CHECK(Result.StatusCode == HttpResponseCode::OK);
-
- zen::CbPackage Package;
- const bool Ok = Package.TryLoad(Result.ResponsePayload);
- CHECK(Ok);
- CHECK(IsEqual(Package, ExpectedPackage));
- }
- }
+ const std::string_view Bucket = "mosdef"sv;
- SUBCASE("PUT propagates upstream")
+ // Phase 1: PUT/GET returns correct package (via local)
{
- // Setup local and remote server
- std::filesystem::path LocalDataDir = TestEnv.CreateNewTestDir();
- std::filesystem::path RemoteDataDir = TestEnv.CreateNewTestDir();
-
- ZenServerInstance RemoteInstance(TestEnv);
- RemoteInstance.SetDataDir(RemoteDataDir);
- const uint16_t RemotePortNumber = RemoteInstance.SpawnServerAndWaitUntilReady();
-
- ZenServerInstance LocalInstance(TestEnv);
- LocalInstance.SetDataDir(LocalDataDir);
- LocalInstance.SpawnServer(TestEnv.GetNewPortNumber(),
- fmt::format("--upstream-thread-count=0 --upstream-zen-url=http://localhost:{}", RemotePortNumber));
- const uint16_t LocalPortNumber = LocalInstance.WaitUntilReady();
- CHECK_MESSAGE(LocalPortNumber != 0, LocalInstance.GetLogOutput());
-
- const auto LocalBaseUri = fmt::format("http://localhost:{}/z$", LocalPortNumber);
- const auto RemoteBaseUri = fmt::format("http://localhost:{}/z$", RemotePortNumber);
-
- const std::string_view Bucket = "mosdef"sv;
- zen::IoHash Key;
- zen::CbPackage ExpectedPackage = CreateTestPackage(Key);
-
- HttpClient LocalHttp{LocalBaseUri};
- HttpClient RemoteHttp{RemoteBaseUri};
-
- // Store the cache record package in the local instance
- {
- zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage);
- HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}", Bucket, Key), Body);
-
- CHECK(Result.StatusCode == HttpResponseCode::Created);
- }
+ zen::IoHash Key1;
+ zen::CbPackage ExpectedPackage = CreateTestPackage(Key1);
- // The cache record can be retrieved as a package from the local instance
- {
- HttpClient::Response Result = LocalHttp.Get(fmt::format("/{}/{}", Bucket, Key), {{"Accept", "application/x-ue-cbpkg"}});
- CHECK(Result.StatusCode == HttpResponseCode::OK);
+ zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage);
+ HttpClient::Response PutResult = LocalHttp.Put(fmt::format("/{}/{}", Bucket, Key1), Body);
+ CHECK(PutResult.StatusCode == HttpResponseCode::Created);
- zen::CbPackage Package;
- const bool Ok = Package.TryLoad(Result.ResponsePayload);
- CHECK(Ok);
- CHECK(IsEqual(Package, ExpectedPackage));
- }
+ HttpClient::Response GetResult = LocalHttp.Get(fmt::format("/{}/{}", Bucket, Key1), {{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(GetResult.StatusCode == HttpResponseCode::OK);
- // The cache record can be retrieved as a package from the remote instance
- {
- HttpClient::Response Result = RemoteHttp.Get(fmt::format("/{}/{}", Bucket, Key), {{"Accept", "application/x-ue-cbpkg"}});
- CHECK(Result.StatusCode == HttpResponseCode::OK);
-
- zen::CbPackage Package;
- const bool Ok = Package.TryLoad(Result.ResponsePayload);
- CHECK(Ok);
- CHECK(IsEqual(Package, ExpectedPackage));
- }
+ zen::CbPackage Package;
+ const bool Ok = Package.TryLoad(GetResult.ResponsePayload);
+ CHECK(Ok);
+ CHECK(IsEqual(Package, ExpectedPackage));
}
- SUBCASE("GET finds upstream when missing in local")
+ // Phase 2: PUT propagates upstream
{
- // Setup local and remote server
- std::filesystem::path LocalDataDir = TestEnv.CreateNewTestDir();
- std::filesystem::path RemoteDataDir = TestEnv.CreateNewTestDir();
+ zen::IoHash Key2;
+ zen::CbPackage ExpectedPackage = CreateTestPackage(Key2);
- ZenServerInstance RemoteInstance(TestEnv);
- RemoteInstance.SetDataDir(RemoteDataDir);
- const uint16_t RemotePortNumber = RemoteInstance.SpawnServerAndWaitUntilReady();
+ zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage);
+ HttpClient::Response PutResult = LocalHttp.Put(fmt::format("/{}/{}", Bucket, Key2), Body);
+ CHECK(PutResult.StatusCode == HttpResponseCode::Created);
- ZenServerInstance LocalInstance(TestEnv);
- LocalInstance.SetDataDir(LocalDataDir);
- LocalInstance.SpawnServer(TestEnv.GetNewPortNumber(),
- fmt::format("--upstream-thread-count=0 --upstream-zen-url=http://localhost:{}", RemotePortNumber));
- const uint16_t LocalPortNumber = LocalInstance.WaitUntilReady();
- CHECK_MESSAGE(LocalPortNumber != 0, LocalInstance.GetLogOutput());
+ HttpClient::Response LocalGetResult = LocalHttp.Get(fmt::format("/{}/{}", Bucket, Key2), {{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(LocalGetResult.StatusCode == HttpResponseCode::OK);
- const auto LocalBaseUri = fmt::format("http://localhost:{}/z$", LocalPortNumber);
- const auto RemoteBaseUri = fmt::format("http://localhost:{}/z$", RemotePortNumber);
+ zen::CbPackage LocalPackage;
+ CHECK(LocalPackage.TryLoad(LocalGetResult.ResponsePayload));
+ CHECK(IsEqual(LocalPackage, ExpectedPackage));
- HttpClient LocalHttp{LocalBaseUri};
- HttpClient RemoteHttp{RemoteBaseUri};
+ HttpClient::Response RemoteGetResult = RemoteHttp.Get(fmt::format("/{}/{}", Bucket, Key2), {{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(RemoteGetResult.StatusCode == HttpResponseCode::OK);
- const std::string_view Bucket = "mosdef"sv;
- zen::IoHash Key;
- zen::CbPackage ExpectedPackage = CreateTestPackage(Key);
+ zen::CbPackage RemotePackage;
+ CHECK(RemotePackage.TryLoad(RemoteGetResult.ResponsePayload));
+ CHECK(IsEqual(RemotePackage, ExpectedPackage));
+ }
- // Store the cache record package in upstream cache
- {
- zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage);
- HttpClient::Response Result = RemoteHttp.Put(fmt::format("/{}/{}", Bucket, Key), Body);
+ // Phase 3: GET finds upstream when missing in local
+ {
+ zen::IoHash Key3;
+ zen::CbPackage ExpectedPackage = CreateTestPackage(Key3);
- CHECK(Result.StatusCode == HttpResponseCode::Created);
- }
+ zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage);
+ HttpClient::Response PutResult = RemoteHttp.Put(fmt::format("/{}/{}", Bucket, Key3), Body);
+ CHECK(PutResult.StatusCode == HttpResponseCode::Created);
- // The cache record can be retrieved as a package from the local cache
- {
- HttpClient::Response Result = LocalHttp.Get(fmt::format("/{}/{}", Bucket, Key), {{"Accept", "application/x-ue-cbpkg"}});
- CHECK(Result.StatusCode == HttpResponseCode::OK);
+ HttpClient::Response GetResult = LocalHttp.Get(fmt::format("/{}/{}", Bucket, Key3), {{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(GetResult.StatusCode == HttpResponseCode::OK);
- zen::CbPackage Package;
- const bool Ok = Package.TryLoad(Result.ResponsePayload);
- CHECK(Ok);
- CHECK(IsEqual(Package, ExpectedPackage));
- }
+ zen::CbPackage Package;
+ CHECK(Package.TryLoad(GetResult.ResponsePayload));
+ CHECK(IsEqual(Package, ExpectedPackage));
}
}
@@ -340,25 +290,25 @@ TEST_CASE("zcache.policy")
return Package;
};
- SUBCASE("query - 'local' does not query upstream (binary)")
- {
- ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber());
- ZenServerInstance UpstreamInst(TestEnv);
- UpstreamCfg.Spawn(UpstreamInst);
- const uint16_t UpstreamPort = UpstreamCfg.Port;
+ ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber());
+ ZenServerInstance UpstreamInst(TestEnv);
+ UpstreamCfg.Spawn(UpstreamInst);
- ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamPort);
- ZenServerInstance LocalInst(TestEnv);
- LocalCfg.Spawn(LocalInst);
+ ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamCfg.Port);
+ ZenServerInstance LocalInst(TestEnv);
+ LocalCfg.Spawn(LocalInst);
+
+ HttpClient LocalHttp{LocalCfg.BaseUri};
+ HttpClient RemoteHttp{UpstreamCfg.BaseUri};
- const std::string_view Bucket = "legacy"sv;
+ // query - 'local' does not query upstream (binary)
+ // Uses size 1024 for unique key
+ {
+ const auto Bucket = "legacy"sv;
zen::IoHash Key;
IoBuffer BinaryValue = GenerateData(1024, Key);
- HttpClient LocalHttp{LocalCfg.BaseUri};
- HttpClient RemoteHttp{UpstreamCfg.BaseUri};
-
{
HttpClient::Response Result = RemoteHttp.Put(fmt::format("/{}/{}", Bucket, Key), BinaryValue);
CHECK(Result.StatusCode == HttpResponseCode::Created);
@@ -377,26 +327,14 @@ TEST_CASE("zcache.policy")
}
}
- SUBCASE("store - 'local' does not store upstream (binary)")
+ // store - 'local' does not store upstream (binary)
+ // Uses size 2048 for unique key
{
- ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber());
- ZenServerInstance UpstreamInst(TestEnv);
- UpstreamCfg.Spawn(UpstreamInst);
- const uint16_t UpstreamPort = UpstreamCfg.Port;
-
- ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamPort);
- ZenServerInstance LocalInst(TestEnv);
- LocalCfg.Spawn(LocalInst);
-
const auto Bucket = "legacy"sv;
zen::IoHash Key;
- IoBuffer BinaryValue = GenerateData(1024, Key);
+ IoBuffer BinaryValue = GenerateData(2048, Key);
- HttpClient LocalHttp{LocalCfg.BaseUri};
- HttpClient RemoteHttp{UpstreamCfg.BaseUri};
-
- // Store binary cache value locally
{
HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}?Policy=Query,StoreLocal", Bucket, Key),
BinaryValue,
@@ -415,25 +353,14 @@ TEST_CASE("zcache.policy")
}
}
- SUBCASE("store - 'local/remote' stores local and upstream (binary)")
+ // store - 'local/remote' stores local and upstream (binary)
+ // Uses size 4096 for unique key
{
- ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber());
- ZenServerInstance UpstreamInst(TestEnv);
- UpstreamCfg.Spawn(UpstreamInst);
-
- ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamCfg.Port);
- ZenServerInstance LocalInst(TestEnv);
- LocalCfg.Spawn(LocalInst);
-
const auto Bucket = "legacy"sv;
zen::IoHash Key;
- IoBuffer BinaryValue = GenerateData(1024, Key);
-
- HttpClient LocalHttp{LocalCfg.BaseUri};
- HttpClient RemoteHttp{UpstreamCfg.BaseUri};
+ IoBuffer BinaryValue = GenerateData(4096, Key);
- // Store binary cache value locally and upstream
{
HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}?Policy=Query,Store", Bucket, Key),
BinaryValue,
@@ -452,27 +379,16 @@ TEST_CASE("zcache.policy")
}
}
- SUBCASE("query - 'local' does not query upstream (cbpackage)")
+ // query - 'local' does not query upstream (cbpackage)
+ // Uses bucket "policy4" to isolate from other cbpackage scenarios (deterministic key)
{
- ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber());
- ZenServerInstance UpstreamInst(TestEnv);
- UpstreamCfg.Spawn(UpstreamInst);
-
- ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamCfg.Port);
- ZenServerInstance LocalInst(TestEnv);
- LocalCfg.Spawn(LocalInst);
-
- const auto Bucket = "legacy"sv;
+ const auto Bucket = "policy4"sv;
zen::IoHash Key;
zen::IoHash PayloadId;
zen::CbPackage Package = GeneratePackage(Key, PayloadId);
IoBuffer Buf = SerializeToBuffer(Package);
- HttpClient LocalHttp{LocalCfg.BaseUri};
- HttpClient RemoteHttp{UpstreamCfg.BaseUri};
-
- // Store package upstream
{
HttpClient::Response Result = RemoteHttp.Put(fmt::format("/{}/{}", Bucket, Key), Buf);
CHECK(Result.StatusCode == HttpResponseCode::Created);
@@ -491,27 +407,16 @@ TEST_CASE("zcache.policy")
}
}
- SUBCASE("store - 'local' does not store upstream (cbpackage)")
+ // store - 'local' does not store upstream (cbpackage)
+ // Uses bucket "policy5" to isolate from other cbpackage scenarios (deterministic key)
{
- ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber());
- ZenServerInstance UpstreamInst(TestEnv);
- UpstreamCfg.Spawn(UpstreamInst);
-
- ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamCfg.Port);
- ZenServerInstance LocalInst(TestEnv);
- LocalCfg.Spawn(LocalInst);
-
- const auto Bucket = "legacy"sv;
+ const auto Bucket = "policy5"sv;
zen::IoHash Key;
zen::IoHash PayloadId;
zen::CbPackage Package = GeneratePackage(Key, PayloadId);
IoBuffer Buf = SerializeToBuffer(Package);
- HttpClient LocalHttp{LocalCfg.BaseUri};
- HttpClient RemoteHttp{UpstreamCfg.BaseUri};
-
- // Store package locally
{
HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}?Policy=Query,StoreLocal", Bucket, Key), Buf);
CHECK(Result.StatusCode == HttpResponseCode::Created);
@@ -528,27 +433,16 @@ TEST_CASE("zcache.policy")
}
}
- SUBCASE("store - 'local/remote' stores local and upstream (cbpackage)")
+ // store - 'local/remote' stores local and upstream (cbpackage)
+ // Uses bucket "policy6" to isolate from other cbpackage scenarios (deterministic key)
{
- ZenConfig UpstreamCfg = ZenConfig::New(TestEnv.GetNewPortNumber());
- ZenServerInstance UpstreamInst(TestEnv);
- UpstreamCfg.Spawn(UpstreamInst);
-
- ZenConfig LocalCfg = ZenConfig::NewWithUpstream(TestEnv.GetNewPortNumber(), UpstreamCfg.Port);
- ZenServerInstance LocalInst(TestEnv);
- LocalCfg.Spawn(LocalInst);
-
- const auto Bucket = "legacy"sv;
+ const auto Bucket = "policy6"sv;
zen::IoHash Key;
zen::IoHash PayloadId;
zen::CbPackage Package = GeneratePackage(Key, PayloadId);
IoBuffer Buf = SerializeToBuffer(Package);
- HttpClient LocalHttp{LocalCfg.BaseUri};
- HttpClient RemoteHttp{UpstreamCfg.BaseUri};
-
- // Store package locally and upstream
{
HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}?Policy=Query,Store", Bucket, Key), Buf);
CHECK(Result.StatusCode == HttpResponseCode::Created);
@@ -565,78 +459,62 @@ TEST_CASE("zcache.policy")
}
}
- SUBCASE("skip - 'data' returns cache record without attachments/empty payload")
+ // skip - 'data' returns cache record without attachments/empty payload
+ // Uses bucket "skiptest7" to isolate from other cbpackage scenarios
{
- ZenConfig Cfg = ZenConfig::New(TestEnv.GetNewPortNumber());
- ZenServerInstance Instance(TestEnv);
- Cfg.Spawn(Instance);
-
- const auto Bucket = "test"sv;
+ const auto Bucket = "skiptest7"sv;
zen::IoHash Key;
zen::IoHash PayloadId;
zen::CbPackage Package = GeneratePackage(Key, PayloadId);
IoBuffer Buf = SerializeToBuffer(Package);
- HttpClient Http{Cfg.BaseUri};
-
- // Store package
{
- HttpClient::Response Result = Http.Put(fmt::format("/{}/{}", Bucket, Key), Buf);
+ HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}", Bucket, Key), Buf);
CHECK(Result.StatusCode == HttpResponseCode::Created);
}
- // Get package
{
HttpClient::Response Result =
- Http.Get(fmt::format("/{}/{}?Policy=Default,SkipData", Bucket, Key), {{"Accept", "application/x-ue-cbpkg"}});
+ LocalHttp.Get(fmt::format("/{}/{}?Policy=Default,SkipData", Bucket, Key), {{"Accept", "application/x-ue-cbpkg"}});
CHECK(Result);
CbPackage ResponsePackage;
CHECK(ResponsePackage.TryLoad(Result.ResponsePayload));
CHECK(ResponsePackage.GetAttachments().size() == 0);
}
- // Get record
{
HttpClient::Response Result =
- Http.Get(fmt::format("/{}/{}?Policy=Default,SkipData", Bucket, Key), {{"Accept", "application/x-ue-cb"}});
+ LocalHttp.Get(fmt::format("/{}/{}?Policy=Default,SkipData", Bucket, Key), {{"Accept", "application/x-ue-cb"}});
CHECK(Result);
CbObject ResponseObject = zen::LoadCompactBinaryObject(Result.ResponsePayload);
CHECK(ResponseObject);
}
- // Get payload
{
- HttpClient::Response Result =
- Http.Get(fmt::format("/{}/{}/{}?Policy=Default,SkipData", Bucket, Key, PayloadId), {{"Accept", "application/x-ue-comp"}});
+ HttpClient::Response Result = LocalHttp.Get(fmt::format("/{}/{}/{}?Policy=Default,SkipData", Bucket, Key, PayloadId),
+ {{"Accept", "application/x-ue-comp"}});
CHECK(Result);
CHECK(Result.ResponsePayload.GetSize() == 0);
}
}
- SUBCASE("skip - 'data' returns empty binary value")
+ // skip - 'data' returns empty binary value
+ // Uses size 8192 for unique key (avoids collision with size 1024/2048/4096 above)
{
- ZenConfig Cfg = ZenConfig::New(TestEnv.GetNewPortNumber());
- ZenServerInstance Instance(TestEnv);
- Cfg.Spawn(Instance);
-
- const auto Bucket = "test"sv;
+ const auto Bucket = "skiptest8"sv;
zen::IoHash Key;
- IoBuffer BinaryValue = GenerateData(1024, Key);
-
- HttpClient Http{Cfg.BaseUri};
+ IoBuffer BinaryValue = GenerateData(8192, Key);
- // Store binary cache value
{
- HttpClient::Response Result = Http.Put(fmt::format("/{}/{}", Bucket, Key), BinaryValue);
+ HttpClient::Response Result = LocalHttp.Put(fmt::format("/{}/{}", Bucket, Key), BinaryValue);
CHECK(Result.StatusCode == HttpResponseCode::Created);
}
- // Get package
{
HttpClient::Response Result =
- Http.Get(fmt::format("/{}/{}?Policy=Default,SkipData", Bucket, Key), {{"Accept", "application/octet-stream"}});
+ LocalHttp.Get(fmt::format("/{}/{}?Policy=Default,SkipData", Bucket, Key), {{"Accept", "application/octet-stream"}});
CHECK(Result);
CHECK(Result.ResponsePayload.GetSize() == 0);
}
@@ -743,7 +621,11 @@ TEST_CASE("zcache.rpc")
if (Result.StatusCode == HttpResponseCode::OK)
{
- CbPackage Response = ParsePackageMessage(Result.ResponsePayload);
+ ParseFlags PFlags = EnumHasAllFlags(AcceptOptions, RpcAcceptOptions::kAllowLocalReferences) ? ParseFlags::kAllowLocalReferences
+ : ParseFlags::kDefault;
+ PermissiveLocalRefPolicy AllowAllPolicy;
+ const ILocalRefPolicy* PPolicy = EnumHasAllFlags(PFlags, ParseFlags::kAllowLocalReferences) ? &AllowAllPolicy : nullptr;
+ CbPackage Response = ParsePackageMessage(Result.ResponsePayload, {}, PFlags, PPolicy);
CHECK(!Response.IsNull());
OutResult.Response = std::move(Response);
CHECK(OutResult.Result.Parse(OutResult.Response));
@@ -1745,8 +1627,13 @@ TEST_CASE("zcache.rpc.partialchunks")
CHECK(Result.StatusCode == HttpResponseCode::OK);
- CbPackage Response = ParsePackageMessage(Result.ResponsePayload);
- bool Loaded = !Response.IsNull();
+ ParseFlags PFlags = EnumHasAllFlags(Options.AcceptOptions, RpcAcceptOptions::kAllowLocalReferences)
+ ? ParseFlags::kAllowLocalReferences
+ : ParseFlags::kDefault;
+ PermissiveLocalRefPolicy AllowAllPolicy;
+ const ILocalRefPolicy* PPolicy = EnumHasAllFlags(PFlags, ParseFlags::kAllowLocalReferences) ? &AllowAllPolicy : nullptr;
+ CbPackage Response = ParsePackageMessage(Result.ResponsePayload, {}, PFlags, PPolicy);
+ bool Loaded = !Response.IsNull();
CHECK_MESSAGE(Loaded, "GetCacheChunks response failed to load.");
cacherequests::GetCacheChunksResult GetCacheChunksResult;
CHECK(GetCacheChunksResult.Parse(Response));
diff --git a/src/zenserver-test/compute-tests.cpp b/src/zenserver-test/compute-tests.cpp
index 021052a3b..a4755adec 100644
--- a/src/zenserver-test/compute-tests.cpp
+++ b/src/zenserver-test/compute-tests.cpp
@@ -21,6 +21,7 @@
# include <zenhttp/httpserver.h>
# include <zenhttp/websocket.h>
# include <zencompute/computeservice.h>
+# include <zencore/fmtutils.h>
# include <zenstore/zenstore.h>
# include <zenutil/zenserverprocess.h>
@@ -36,6 +37,8 @@ using namespace std::literals;
static constexpr std::string_view kBuildSystemVersion = "17fe280d-ccd8-4be8-a9d1-89c944a70969";
static constexpr std::string_view kRot13Version = "13131313-1313-1313-1313-131313131313";
static constexpr std::string_view kSleepVersion = "88888888-8888-8888-8888-888888888888";
+static constexpr std::string_view kFailVersion = "fa11fa11-fa11-fa11-fa11-fa11fa11fa11";
+static constexpr std::string_view kCrashVersion = "c4a50000-c4a5-c4a5-c4a5-c4a5c4a5c4a5";
// In-memory implementation of ChunkResolver for test use.
// Stores compressed data keyed by decompressed content hash.
@@ -104,6 +107,16 @@ RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env)
<< "Sleep"sv;
WorkerWriter << "version"sv << Guid::FromString(kSleepVersion);
WorkerWriter.EndObject();
+ WorkerWriter.BeginObject();
+ WorkerWriter << "name"sv
+ << "Fail"sv;
+ WorkerWriter << "version"sv << Guid::FromString(kFailVersion);
+ WorkerWriter.EndObject();
+ WorkerWriter.BeginObject();
+ WorkerWriter << "name"sv
+ << "Crash"sv;
+ WorkerWriter << "version"sv << Guid::FromString(kCrashVersion);
+ WorkerWriter.EndObject();
WorkerWriter.EndArray();
CbPackage WorkerPackage;
@@ -115,7 +128,7 @@ RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env)
const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString());
HttpClient::Response RegisterResp = Client.Post(WorkerUrl, std::move(WorkerPackage));
REQUIRE_MESSAGE(RegisterResp,
- fmt::format("Worker registration failed: status={}, body={}", int(RegisterResp.StatusCode), RegisterResp.ToText()));
+ fmt::format("Worker registration failed: status={}, body={}", RegisterResp.StatusCode, RegisterResp.ToText()));
return WorkerId;
}
@@ -220,6 +233,83 @@ BuildSleepActionForSession(std::string_view Input, uint64_t SleepTimeMs, InMemor
return ActionWriter.Save();
}
+// Build a Fail action CbPackage. The worker exits with the given exit code.
+static CbPackage
+BuildFailActionPackage(int ExitCode)
+{
+ // The Fail function throws before reading inputs, but the action structure
+ // still requires a valid input attachment for the runner to manifest.
+ std::string_view Dummy = "x"sv;
+
+ CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Dummy.data(), Dummy.size()),
+ OodleCompressor::Selkie,
+ OodleCompressionLevel::HyperFast4);
+
+ const IoHash InputRawHash = InputCompressed.DecodeRawHash();
+ const uint64_t InputRawSize = Dummy.size();
+
+ CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash);
+
+ CbObjectWriter ActionWriter;
+ ActionWriter << "Function"sv
+ << "Fail"sv;
+ ActionWriter << "FunctionVersion"sv << Guid::FromString(kFailVersion);
+ ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion);
+ ActionWriter.BeginObject("Inputs"sv);
+ ActionWriter.BeginObject("Source"sv);
+ ActionWriter.AddAttachment("RawHash"sv, InputAttachment);
+ ActionWriter << "RawSize"sv << InputRawSize;
+ ActionWriter.EndObject();
+ ActionWriter.EndObject();
+ ActionWriter.BeginObject("Constants"sv);
+ ActionWriter << "ExitCode"sv << static_cast<uint64_t>(ExitCode);
+ ActionWriter.EndObject();
+
+ CbPackage ActionPackage;
+ ActionPackage.SetObject(ActionWriter.Save());
+ ActionPackage.AddAttachment(InputAttachment);
+
+ return ActionPackage;
+}
+
+// Build a Crash action CbPackage. The worker process crashes hard.
+// Mode: "abort" (default) or "nullptr" (null pointer dereference).
+static CbPackage
+BuildCrashActionPackage(std::string_view Mode = "abort"sv)
+{
+ std::string_view Dummy = "x"sv;
+
+ CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Dummy.data(), Dummy.size()),
+ OodleCompressor::Selkie,
+ OodleCompressionLevel::HyperFast4);
+
+ const IoHash InputRawHash = InputCompressed.DecodeRawHash();
+ const uint64_t InputRawSize = Dummy.size();
+
+ CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash);
+
+ CbObjectWriter ActionWriter;
+ ActionWriter << "Function"sv
+ << "Crash"sv;
+ ActionWriter << "FunctionVersion"sv << Guid::FromString(kCrashVersion);
+ ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion);
+ ActionWriter.BeginObject("Inputs"sv);
+ ActionWriter.BeginObject("Source"sv);
+ ActionWriter.AddAttachment("RawHash"sv, InputAttachment);
+ ActionWriter << "RawSize"sv << InputRawSize;
+ ActionWriter.EndObject();
+ ActionWriter.EndObject();
+ ActionWriter.BeginObject("Constants"sv);
+ ActionWriter << "Mode"sv << Mode;
+ ActionWriter.EndObject();
+
+ CbPackage ActionPackage;
+ ActionPackage.SetObject(ActionWriter.Save());
+ ActionPackage.AddAttachment(InputAttachment);
+
+ return ActionPackage;
+}
+
static HttpClient::Response
PollForResult(HttpClient& Client, const std::string& ResultUrl, uint64_t TimeoutMs = 30'000)
{
@@ -267,6 +357,41 @@ PollForLsnInCompleted(HttpClient& Client, const std::string& CompletedUrl, int L
return false;
}
+static void
+WaitForActionRunning(zen::compute::ComputeServiceSession& Session, uint64_t TimeoutMs = 10'000)
+{
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < TimeoutMs)
+ {
+ if (Session.GetActionCounts().Running > 0)
+ {
+ return;
+ }
+ Sleep(50);
+ }
+ FAIL("Timed out waiting for action to reach Running state");
+}
+
+static void
+WaitForAnyActionRunningHttp(HttpClient& Client, uint64_t TimeoutMs = 10'000)
+{
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < TimeoutMs)
+ {
+ HttpClient::Response Resp = Client.Get("/jobs/running"sv);
+ if (Resp)
+ {
+ CbObject Obj = Resp.AsObject();
+ if (Obj["running"sv].AsArrayView().Num() > 0)
+ {
+ return;
+ }
+ }
+ Sleep(50);
+ }
+ FAIL("Timed out waiting for any action to reach Running state");
+}
+
static std::string
GetRot13Output(const CbPackage& ResultPackage)
{
@@ -340,8 +465,9 @@ public:
}
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override
{
+ ZEN_UNUSED(RelativeUri);
m_WsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
}
@@ -469,6 +595,16 @@ BuildWorkerPackage(ZenServerEnvironment& Env, InMemoryChunkResolver& Resolver)
<< "Sleep"sv;
WorkerWriter << "version"sv << Guid::FromString(kSleepVersion);
WorkerWriter.EndObject();
+ WorkerWriter.BeginObject();
+ WorkerWriter << "name"sv
+ << "Fail"sv;
+ WorkerWriter << "version"sv << Guid::FromString(kFailVersion);
+ WorkerWriter.EndObject();
+ WorkerWriter.BeginObject();
+ WorkerWriter << "name"sv
+ << "Crash"sv;
+ WorkerWriter << "version"sv << Guid::FromString(kCrashVersion);
+ WorkerWriter.EndObject();
WorkerWriter.EndArray();
CbPackage WorkerPackage;
@@ -526,7 +662,7 @@ TEST_CASE("function.rot13")
// Submit action via legacy /jobs/{worker} endpoint
const std::string JobUrl = fmt::format("/jobs/{}", WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
- REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission");
@@ -536,7 +672,7 @@ TEST_CASE("function.rot13")
HttpClient::Response ResultResp = PollForResult(Client, ResultUrl);
REQUIRE_MESSAGE(
ResultResp.StatusCode == HttpResponseCode::OK,
- fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput()));
+ fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", ResultResp.StatusCode, Instance.GetLogOutput()));
// Verify result: Rot13("Hello World") == "Uryyb Jbeyq"
CbPackage ResultPackage = ResultResp.AsPackage();
@@ -562,7 +698,7 @@ TEST_CASE("function.workers")
const IoHash WorkerId = RegisterWorker(Client, TestEnv);
- // GET /workers — the registered worker should appear in the listing
+ // GET /workers - the registered worker should appear in the listing
HttpClient::Response ListResp = Client.Get("/workers"sv);
REQUIRE_MESSAGE(ListResp, "Failed to list workers after registration");
@@ -578,10 +714,10 @@ TEST_CASE("function.workers")
REQUIRE_MESSAGE(WorkerFound, fmt::format("Worker {} not found in worker listing", WorkerId.ToHexString()));
- // GET /workers/{worker} — descriptor should match what was registered
+ // GET /workers/{worker} - descriptor should match what was registered
const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString());
HttpClient::Response DescResp = Client.Get(WorkerUrl);
- REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", int(DescResp.StatusCode)));
+ REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", DescResp.StatusCode));
CbObject Desc = DescResp.AsObject();
CHECK_EQ(Desc["buildsystem_version"sv].AsUuid(), Guid::FromString(kBuildSystemVersion));
@@ -607,7 +743,7 @@ TEST_CASE("function.workers")
CHECK_MESSAGE(Rot13Found, "Rot13 function not found in worker descriptor");
CHECK_MESSAGE(SleepFound, "Sleep function not found in worker descriptor");
- // GET /workers/{unknown} — should return 404
+ // GET /workers/{unknown} - should return 404
const std::string UnknownUrl = fmt::format("/workers/{}", IoHash::Zero.ToHexString());
HttpClient::Response NotFoundResp = Client.Get(UnknownUrl);
CHECK_EQ(NotFoundResp.StatusCode, HttpResponseCode::NotFound);
@@ -627,7 +763,7 @@ TEST_CASE("function.queues.lifecycle")
// Create a queue
HttpClient::Response CreateResp = Client.Post("/queues"sv);
- REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText()));
+ REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText()));
const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation");
@@ -651,8 +787,7 @@ TEST_CASE("function.queues.lifecycle")
// Submit action via queue-scoped endpoint
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
- REQUIRE_MESSAGE(SubmitResp,
- fmt::format("Queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Queue job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from queue job submission");
@@ -668,9 +803,8 @@ TEST_CASE("function.queues.lifecycle")
// Retrieve result via queue-scoped /jobs/{lsn} endpoint
const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn);
HttpClient::Response ResultResp = Client.Get(ResultUrl);
- REQUIRE_MESSAGE(
- ResultResp.StatusCode == HttpResponseCode::OK,
- fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput()));
+ REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK,
+ fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", ResultResp.StatusCode, Instance.GetLogOutput()));
// Verify result: Rot13("Hello World") == "Uryyb Jbeyq"
CbPackage ResultPackage = ResultResp.AsPackage();
@@ -712,13 +846,13 @@ TEST_CASE("function.queues.cancel")
// Submit a job
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
- REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
// Cancel the queue
const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
HttpClient::Response CancelResp = Client.Delete(QueueUrl);
REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent,
- fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText()));
+ fmt::format("Queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText()));
// Verify queue status shows cancelled
HttpClient::Response StatusResp = Client.Get(QueueUrl);
@@ -740,10 +874,10 @@ TEST_CASE("function.queues.remote")
const IoHash WorkerId = RegisterWorker(Client, TestEnv);
- // Create a remote queue — response includes both an integer queue_id and an OID queue_token
+ // Create a remote queue - response includes both an integer queue_id and an OID queue_token
HttpClient::Response CreateResp = Client.Post("/queues/remote"sv);
REQUIRE_MESSAGE(CreateResp,
- fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText()));
+ fmt::format("Remote queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText()));
CbObject CreateObj = CreateResp.AsObject();
const std::string QueueToken = std::string(CreateObj["queue_token"sv].AsString());
@@ -753,7 +887,7 @@ TEST_CASE("function.queues.remote")
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
REQUIRE_MESSAGE(SubmitResp,
- fmt::format("Remote queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+ fmt::format("Remote queue job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from remote queue job submission");
@@ -769,7 +903,7 @@ TEST_CASE("function.queues.remote")
HttpClient::Response ResultResp = Client.Get(ResultUrl);
REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK,
fmt::format("Failed to retrieve result from remote queue: status={}\nServer log:\n{}",
- int(ResultResp.StatusCode),
+ ResultResp.StatusCode,
Instance.GetLogOutput()));
// Verify result: Rot13("Hello World") == "Uryyb Jbeyq"
@@ -801,20 +935,19 @@ TEST_CASE("function.queues.cancel_running")
// Submit a Sleep job long enough that it will still be running when we cancel
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000));
- REQUIRE_MESSAGE(SubmitResp,
- fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission");
// Wait for the worker process to start executing before cancelling
- Sleep(1'000);
+ const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
+ WaitForAnyActionRunningHttp(Client);
// Cancel the queue, which should interrupt the running Sleep job
- const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
HttpClient::Response CancelResp = Client.Delete(QueueUrl);
REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent,
- fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText()));
+ fmt::format("Queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText()));
// The cancelled job should appear in the /completed endpoint once the process exits
const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
@@ -849,7 +982,7 @@ TEST_CASE("function.queues.remote_cancel")
// Create a remote queue to obtain an OID token for token-addressed cancellation
HttpClient::Response CreateResp = Client.Post("/queues/remote"sv);
REQUIRE_MESSAGE(CreateResp,
- fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText()));
+ fmt::format("Remote queue creation failed: status={}, body={}", CreateResp.StatusCode, CreateResp.ToText()));
const std::string QueueToken = std::string(CreateResp.AsObject()["queue_token"sv].AsString());
REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation");
@@ -857,20 +990,19 @@ TEST_CASE("function.queues.remote_cancel")
// Submit a long-running Sleep job via the token-addressed endpoint
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000));
- REQUIRE_MESSAGE(SubmitResp,
- fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission");
// Wait for the worker process to start executing before cancelling
- Sleep(1'000);
+ const std::string QueueUrl = fmt::format("/queues/{}", QueueToken);
+ WaitForAnyActionRunningHttp(Client);
// Cancel the queue via its OID token
- const std::string QueueUrl = fmt::format("/queues/{}", QueueToken);
HttpClient::Response CancelResp = Client.Delete(QueueUrl);
REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent,
- fmt::format("Remote queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText()));
+ fmt::format("Remote queue cancellation failed: status={}, body={}", CancelResp.StatusCode, CancelResp.ToText()));
// The cancelled job should appear in the token-addressed /completed endpoint
const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken);
@@ -910,13 +1042,13 @@ TEST_CASE("function.queues.drain")
// Submit a long-running job so we can verify it completes even after drain
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
HttpClient::Response Submit1 = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 2'000));
- REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", int(Submit1.StatusCode)));
+ REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", Submit1.StatusCode));
const int Lsn1 = Submit1.AsObject()["lsn"sv].AsInt32();
// Drain the queue
const std::string DrainUrl = fmt::format("/queues/{}/drain", QueueId);
HttpClient::Response DrainResp = Client.Post(DrainUrl);
- REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", int(DrainResp.StatusCode), DrainResp.ToText()));
+ REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", DrainResp.StatusCode, DrainResp.ToText()));
CHECK_EQ(std::string(DrainResp.AsObject()["state"sv].AsString()), "draining");
// Second submission should be rejected with 424
@@ -965,7 +1097,7 @@ TEST_CASE("function.priority")
// jobs by priority when the slot becomes free.
const std::string BlockerJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString());
HttpClient::Response BlockerResp = Client.Post(BlockerJobUrl, BuildSleepActionPackage("data"sv, 1'000));
- REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", int(BlockerResp.StatusCode)));
+ REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", BlockerResp.StatusCode));
// Submit 3 low-priority Rot13 jobs
const std::string LowJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString());
@@ -982,7 +1114,7 @@ TEST_CASE("function.priority")
REQUIRE_MESSAGE(LowResp3, "Low-priority job 3 submission failed");
const int LsnLow3 = LowResp3.AsObject()["lsn"sv].AsInt32();
- // Submit 1 high-priority Rot13 job — should execute before the low-priority ones
+ // Submit 1 high-priority Rot13 job - should execute before the low-priority ones
const std::string HighJobUrl = fmt::format("/queues/{}/jobs/{}?priority=10", QueueId, WorkerId.ToHexString());
HttpClient::Response HighResp = Client.Post(HighJobUrl, BuildRot13ActionPackage("high"sv));
REQUIRE_MESSAGE(HighResp, "High-priority job submission failed");
@@ -1104,6 +1236,305 @@ TEST_CASE("function.priority")
}
//////////////////////////////////////////////////////////////////////////
+// Process exit code tests
+//
+// These tests exercise how the compute service handles worker processes
+// that exit with non-zero exit codes, including retry behaviour and
+// final failure reporting.
+
+TEST_CASE("function.exit_code")
+{
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+ REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+ HttpClient Client(ComputeBaseUri);
+
+ const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+ auto CreateQueue = [&](int MaxRetries) -> std::pair<int, std::string> {
+ CbObjectWriter ConfigWriter;
+ ConfigWriter << "max_retries"sv << MaxRetries;
+ CbObjectWriter BodyWriter;
+ BodyWriter << "config"sv << ConfigWriter.Save();
+ HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save());
+ REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode));
+ const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+ return {QueueId, fmt::format("/queues/{}", QueueId)};
+ };
+
+ // Scenario 1: failed_action - immediate failure with max_retries=0
+ {
+ auto [QueueId, QueueUrl] = CreateQueue(0);
+
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+ HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(42));
+ REQUIRE_MESSAGE(SubmitResp,
+ fmt::format("Fail job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
+
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+ REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Fail job submission");
+
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+ fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}",
+ Lsn,
+ QueueId,
+ Instance.GetLogOutput()));
+
+ HttpClient::Response StatusResp = Client.Get(QueueUrl);
+ REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+ CbObject QueueStatus = StatusResp.AsObject();
+ CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+
+ const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId);
+ HttpClient::Response HistoryResp = Client.Get(HistoryUrl);
+ REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history");
+
+ bool FoundInHistory = false;
+ for (auto& Item : HistoryResp.AsObject()["history"sv])
+ {
+ if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn)
+ {
+ CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false);
+ FoundInHistory = true;
+ break;
+ }
+ }
+ CHECK_MESSAGE(FoundInHistory, fmt::format("LSN {} not found in action history", Lsn));
+
+ const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn);
+ HttpClient::Response ResultResp = Client.Get(ResultUrl);
+ CHECK_EQ(ResultResp.StatusCode, HttpResponseCode::OK);
+ }
+
+ // Scenario 2: auto_retry - retried twice before permanent failure
+ {
+ auto [QueueId, QueueUrl] = CreateQueue(2);
+
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+ HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(1));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}", SubmitResp.StatusCode));
+
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000),
+ fmt::format("LSN {} did not appear in queue {} completed list after retries\nServer log:\n{}",
+ Lsn,
+ QueueId,
+ Instance.GetLogOutput()));
+
+ const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId);
+ HttpClient::Response HistoryResp = Client.Get(HistoryUrl);
+ REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history");
+
+ for (auto& Item : HistoryResp.AsObject()["history"sv])
+ {
+ if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn)
+ {
+ CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false);
+ CHECK_EQ(Item.AsObjectView()["retry_count"sv].AsInt32(), 2);
+ break;
+ }
+ }
+
+ HttpClient::Response StatusResp = Client.Get(QueueUrl);
+ REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+ CbObject QueueStatus = StatusResp.AsObject();
+ CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+ }
+
+ // Scenario 3: reschedule_failed - manual reschedule rejected after retry limit
+ {
+ auto [QueueId, QueueUrl] = CreateQueue(1);
+
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+ HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildFailActionPackage(7));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Fail job submission failed: status={}", SubmitResp.StatusCode));
+
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000),
+ fmt::format("LSN {} did not appear in queue completed list\nServer log:\n{}", Lsn, Instance.GetLogOutput()));
+
+ const std::string RescheduleUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn);
+ HttpClient::Response RescheduleResp = Client.Post(RescheduleUrl);
+ CHECK_EQ(RescheduleResp.StatusCode, HttpResponseCode::Conflict);
+ }
+
+ // Scenario 4: mixed_success_and_failure - one success and one failure in the same queue
+ {
+ auto [QueueId, QueueUrl] = CreateQueue(0);
+
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+
+ HttpClient::Response SuccessResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello"sv));
+ REQUIRE_MESSAGE(SuccessResp, "Rot13 job submission failed");
+ const int LsnSuccess = SuccessResp.AsObject()["lsn"sv].AsInt32();
+
+ HttpClient::Response FailResp = Client.Post(JobUrl, BuildFailActionPackage(1));
+ REQUIRE_MESSAGE(FailResp, "Fail job submission failed");
+ const int LsnFail = FailResp.AsObject()["lsn"sv].AsInt32();
+
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, LsnSuccess),
+ fmt::format("Success LSN {} did not complete\nServer log:\n{}", LsnSuccess, Instance.GetLogOutput()));
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, LsnFail),
+ fmt::format("Fail LSN {} did not complete\nServer log:\n{}", LsnFail, Instance.GetLogOutput()));
+
+ HttpClient::Response StatusResp = Client.Get(QueueUrl);
+ REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+ CbObject QueueStatus = StatusResp.AsObject();
+ CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0);
+ }
+}
+
+TEST_CASE("function.crash")
+{
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+ REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+ HttpClient Client(ComputeBaseUri);
+
+ const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+ auto CreateQueue = [&](int MaxRetries) -> std::pair<int, std::string> {
+ CbObjectWriter ConfigWriter;
+ ConfigWriter << "max_retries"sv << MaxRetries;
+ CbObjectWriter BodyWriter;
+ BodyWriter << "config"sv << ConfigWriter.Save();
+ HttpClient::Response CreateResp = Client.Post("/queues"sv, BodyWriter.Save());
+ REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}", CreateResp.StatusCode));
+ const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+ return {QueueId, fmt::format("/queues/{}", QueueId)};
+ };
+
+ // Scenario 1: abort - worker process calls std::abort(), no retries
+ {
+ auto [QueueId, QueueUrl] = CreateQueue(0);
+
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+ HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("abort"sv));
+ REQUIRE_MESSAGE(SubmitResp,
+ fmt::format("Crash job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
+
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+ REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Crash job submission");
+
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+ fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}",
+ Lsn,
+ QueueId,
+ Instance.GetLogOutput()));
+
+ HttpClient::Response StatusResp = Client.Get(QueueUrl);
+ REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+ CbObject QueueStatus = StatusResp.AsObject();
+ CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+
+ const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId);
+ HttpClient::Response HistoryResp = Client.Get(HistoryUrl);
+ REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history");
+
+ bool FoundInHistory = false;
+ for (auto& Item : HistoryResp.AsObject()["history"sv])
+ {
+ if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn)
+ {
+ CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false);
+ FoundInHistory = true;
+ break;
+ }
+ }
+ CHECK_MESSAGE(FoundInHistory, fmt::format("LSN {} not found in action history", Lsn));
+ }
+
+ // Scenario 2: nullptr - worker process dereferences null, no retries
+ {
+ auto [QueueId, QueueUrl] = CreateQueue(0);
+
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+ HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("nullptr"sv));
+ REQUIRE_MESSAGE(SubmitResp,
+ fmt::format("Crash job submission failed: status={}, body={}", SubmitResp.StatusCode, SubmitResp.ToText()));
+
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+ REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Crash job submission");
+
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+ fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}",
+ Lsn,
+ QueueId,
+ Instance.GetLogOutput()));
+
+ HttpClient::Response StatusResp = Client.Get(QueueUrl);
+ REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+ CbObject QueueStatus = StatusResp.AsObject();
+ CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+ }
+
+ // Scenario 3: auto_retry - crash retried once before permanent failure
+ {
+ auto [QueueId, QueueUrl] = CreateQueue(1);
+
+ const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+ HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildCrashActionPackage("abort"sv));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Crash job submission failed: status={}", SubmitResp.StatusCode));
+
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+
+ const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+ REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn, 60'000),
+ fmt::format("LSN {} did not appear in queue {} completed list after retries\nServer log:\n{}",
+ Lsn,
+ QueueId,
+ Instance.GetLogOutput()));
+
+ const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId);
+ HttpClient::Response HistoryResp = Client.Get(HistoryUrl);
+ REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history");
+
+ for (auto& Item : HistoryResp.AsObject()["history"sv])
+ {
+ if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn)
+ {
+ CHECK_EQ(Item.AsObjectView()["succeeded"sv].AsBool(), false);
+ CHECK_EQ(Item.AsObjectView()["retry_count"sv].AsInt32(), 1);
+ break;
+ }
+ }
+
+ HttpClient::Response StatusResp = Client.Get(QueueUrl);
+ REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+ CbObject QueueStatus = StatusResp.AsObject();
+ CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 1);
+ CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
// Remote worker synchronization tests
//
// These tests exercise the orchestrator discovery path where new compute
@@ -1139,7 +1570,6 @@ TEST_CASE("function.remote.worker_sync_on_discovery")
// Trigger immediate orchestrator re-query and wait for runner setup
Session.NotifyOrchestratorChanged();
- Sleep(2'000);
// Submit Rot13 action via session
CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver);
@@ -1162,9 +1592,8 @@ TEST_CASE("function.remote.worker_sync_on_discovery")
Sleep(200);
}
- REQUIRE_MESSAGE(
- ResultCode == HttpResponseCode::OK,
- fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput()));
+ REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK,
+ fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", ResultCode, Instance.GetLogOutput()));
REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput()));
@@ -1199,7 +1628,6 @@ TEST_CASE("function.remote.late_runner_discovery")
// Wait for W1 discovery
Session.NotifyOrchestratorChanged();
- Sleep(2'000);
// Baseline: submit Rot13 action and verify it completes on W1
{
@@ -1241,27 +1669,37 @@ TEST_CASE("function.remote.late_runner_discovery")
// Wait for W2 discovery
Session.NotifyOrchestratorChanged();
- Sleep(2'000);
- // Verify W2 received the worker by querying its /compute/workers endpoint directly
+ // Poll W2 until the worker has been synced (SyncWorkersToRunner is async)
{
- const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port2);
- HttpClient Client(ComputeBaseUri);
- HttpClient::Response ListResp = Client.Get("/workers"sv);
- REQUIRE_MESSAGE(ListResp, "Failed to list workers on W2");
+ const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port2);
+ HttpClient Client(ComputeBaseUri);
+ bool WorkerFound = false;
+ Stopwatch Timer;
- bool WorkerFound = false;
- for (auto& Item : ListResp.AsObject()["workers"sv])
+ while (Timer.GetElapsedTimeMs() < 10'000)
{
- if (Item.AsHash() == WorkerPackage.GetObjectHash())
+ HttpClient::Response ListResp = Client.Get("/workers"sv);
+ if (ListResp)
+ {
+ for (auto& Item : ListResp.AsObject()["workers"sv])
+ {
+ if (Item.AsHash() == WorkerPackage.GetObjectHash())
+ {
+ WorkerFound = true;
+ break;
+ }
+ }
+ }
+ if (WorkerFound)
{
- WorkerFound = true;
break;
}
+ Sleep(50);
}
REQUIRE_MESSAGE(WorkerFound,
- fmt::format("Worker not found on W2 after discovery — SyncWorkersToRunner may have failed\nW2 log:\n{}",
+ fmt::format("Worker not found on W2 after discovery - SyncWorkersToRunner may have failed\nW2 log:\n{}",
Instance2.GetLogOutput()));
}
@@ -1322,7 +1760,6 @@ TEST_CASE("function.remote.queue_association")
// Wait for scheduler to discover the runner
Session.NotifyOrchestratorChanged();
- Sleep(2'000);
// Create a local queue and submit action to it
auto QueueResult = Session.CreateQueue();
@@ -1349,9 +1786,8 @@ TEST_CASE("function.remote.queue_association")
Sleep(200);
}
- REQUIRE_MESSAGE(
- ResultCode == HttpResponseCode::OK,
- fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput()));
+ REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK,
+ fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", ResultCode, Instance.GetLogOutput()));
REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput()));
CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
@@ -1401,7 +1837,6 @@ TEST_CASE("function.remote.queue_cancel_propagation")
// Wait for scheduler to discover the runner
Session.NotifyOrchestratorChanged();
- Sleep(2'000);
// Create a local queue and submit a long-running Sleep action
auto QueueResult = Session.CreateQueue();
@@ -1414,9 +1849,9 @@ TEST_CASE("function.remote.queue_cancel_propagation")
REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed");
// Wait for the action to start running on the remote
- Sleep(2'000);
+ WaitForActionRunning(Session);
- // Cancel the local queue — this should propagate to the remote
+ // Cancel the local queue - this should propagate to the remote
Session.CancelQueue(QueueId);
// Poll for the action to complete (as cancelled)
@@ -1481,13 +1916,13 @@ TEST_CASE("function.abandon_running_http")
const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000));
- REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", int(SubmitResp.StatusCode)));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", SubmitResp.StatusCode));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN");
// Wait for the process to start running
- Sleep(1'000);
+ WaitForAnyActionRunningHttp(Client);
// Verify the ready endpoint returns OK before abandon
{
@@ -1498,7 +1933,7 @@ TEST_CASE("function.abandon_running_http")
// Trigger abandon via the HTTP endpoint
HttpClient::Response AbandonResp = Client.Post("/abandon"sv);
REQUIRE_MESSAGE(AbandonResp.StatusCode == HttpResponseCode::OK,
- fmt::format("Abandon request failed: status={}, body={}", int(AbandonResp.StatusCode), AbandonResp.ToText()));
+ fmt::format("Abandon request failed: status={}, body={}", AbandonResp.StatusCode, AbandonResp.ToText()));
// Ready endpoint should now return 503
{
@@ -1529,7 +1964,7 @@ TEST_CASE("function.abandon_running_http")
CHECK_MESSAGE(RejectedResp.StatusCode != HttpResponseCode::OK, "Expected action submission to be rejected in Abandoned state");
}
-TEST_CASE("function.session.abandon_pending")
+TEST_CASE("function.session.abandon_pending" * doctest::skip())
{
// Create a session with no runners so actions stay pending
InMemoryChunkResolver Resolver;
@@ -1540,7 +1975,7 @@ TEST_CASE("function.session.abandon_pending")
CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
Session.RegisterWorker(WorkerPackage);
- // Enqueue several actions — they will stay pending because there are no runners
+ // Enqueue several actions - they will stay pending because there are no runners
auto QueueResult = Session.CreateQueue();
REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue");
@@ -1553,7 +1988,7 @@ TEST_CASE("function.session.abandon_pending")
REQUIRE_MESSAGE(Enqueue2, "Failed to enqueue action 2");
REQUIRE_MESSAGE(Enqueue3, "Failed to enqueue action 3");
- // Transition to Abandoned — should mark all pending actions as Abandoned
+ // Transition to Abandoned - should mark all pending actions as Abandoned
bool Transitioned = Session.Abandon();
CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned");
CHECK(Session.GetSessionState() == zen::compute::ComputeServiceSession::SessionState::Abandoned);
@@ -1577,7 +2012,7 @@ TEST_CASE("function.session.abandon_pending")
}
Sleep(100);
}
- CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, int(Code)));
+ CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, Code));
}
// Queue should show 0 active, 3 abandoned
@@ -1589,7 +2024,7 @@ TEST_CASE("function.session.abandon_pending")
auto Rejected = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0);
CHECK_MESSAGE(!Rejected, "Expected action submission to be rejected in Abandoned state");
- // Abandoned → Sunset should be valid
+ // Abandoned -> Sunset should be valid
CHECK(Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Sunset));
Session.Shutdown();
@@ -1618,7 +2053,6 @@ TEST_CASE("function.session.abandon_running")
// Wait for scheduler to discover the runner
Session.NotifyOrchestratorChanged();
- Sleep(2'000);
// Create a queue and submit a long-running Sleep action
auto QueueResult = Session.CreateQueue();
@@ -1631,9 +2065,9 @@ TEST_CASE("function.session.abandon_running")
REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed");
// Wait for the action to start running on the remote
- Sleep(2'000);
+ WaitForActionRunning(Session);
- // Transition to Abandoned — should abandon the running action
+ // Transition to Abandoned - should abandon the running action
bool Transitioned = Session.Abandon();
CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned");
CHECK(!Session.IsHealthy());
@@ -1689,7 +2123,6 @@ TEST_CASE("function.remote.abandon_propagation")
// Wait for scheduler to discover the runner
Session.NotifyOrchestratorChanged();
- Sleep(2'000);
// Create a local queue and submit a long-running Sleep action
auto QueueResult = Session.CreateQueue();
@@ -1702,9 +2135,9 @@ TEST_CASE("function.remote.abandon_propagation")
REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed");
// Wait for the action to start running on the remote
- Sleep(2'000);
+ WaitForActionRunning(Session);
- // Transition to Abandoned — should abandon the running action and propagate
+ // Transition to Abandoned - should abandon the running action and propagate
bool Transitioned = Session.Abandon();
CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned");
@@ -1762,7 +2195,6 @@ TEST_CASE("function.remote.shutdown_cancels_queues")
Session.RegisterWorker(WorkerPackage);
Session.NotifyOrchestratorChanged();
- Sleep(2'000);
// Create a queue and submit a long-running action so the remote queue is established
auto QueueResult = Session.CreateQueue();
@@ -1775,7 +2207,7 @@ TEST_CASE("function.remote.shutdown_cancels_queues")
REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed");
// Wait for the action to start running on the remote
- Sleep(2'000);
+ WaitForActionRunning(Session);
// Verify the remote has a non-implicit queue before shutdown
HttpClient RemoteClient(Instance.GetBaseUri() + "/compute");
@@ -1795,7 +2227,7 @@ TEST_CASE("function.remote.shutdown_cancels_queues")
REQUIRE_MESSAGE(RemoteQueueFound, "Expected remote queue to exist before shutdown");
}
- // Shut down the session — this should cancel all remote queues
+ // Shut down the session - this should cancel all remote queues
Session.Shutdown();
// Verify the remote queue is now cancelled
@@ -1837,7 +2269,6 @@ TEST_CASE("function.remote.shutdown_rejects_new_work")
// Wait for runner discovery
Session.NotifyOrchestratorChanged();
- Sleep(2'000);
// Baseline: submit an action and verify it completes
{
@@ -1865,7 +2296,7 @@ TEST_CASE("function.remote.shutdown_rejects_new_work")
CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
}
- // Shut down — the remote runner should now reject new work
+ // Shut down - the remote runner should now reject new work
Session.Shutdown();
// Attempting to enqueue after shutdown should fail (session is in Sunset state)
@@ -1894,7 +2325,7 @@ TEST_CASE("function.session.retract_pending")
REQUIRE_MESSAGE(Enqueued, "Failed to enqueue action");
// Let the scheduler process the pending action
- Sleep(500);
+ Sleep(50);
// Retract the pending action
auto Result = Session.RetractAction(Enqueued.Lsn);
@@ -1903,7 +2334,7 @@ TEST_CASE("function.session.retract_pending")
// The action should be re-enqueued as pending (still no runners, so stays pending).
// Let the scheduler process the retracted action back to pending.
- Sleep(500);
+ Sleep(50);
// Queue should still show 1 active (the action was rescheduled, not completed)
auto Status = Session.GetQueueStatus(QueueResult.QueueId);
@@ -1955,7 +2386,7 @@ TEST_CASE("function.session.retract_not_terminal")
REQUIRE_MESSAGE(Code == HttpResponseCode::OK, "Action did not complete within timeout");
- // Retract should fail — action already completed (no longer in pending/running maps)
+ // Retract should fail - action already completed (no longer in pending/running maps)
auto RetractResult = Session.RetractAction(Enqueued.Lsn);
CHECK(!RetractResult.Success);
@@ -1979,24 +2410,24 @@ TEST_CASE("function.retract_http")
// Submit a long-running Sleep action to occupy the single execution slot
const std::string BlockerUrl = fmt::format("/jobs/{}", WorkerId.ToHexString());
HttpClient::Response BlockerResp = Client.Post(BlockerUrl, BuildSleepActionPackage("data"sv, 30'000));
- REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", int(BlockerResp.StatusCode)));
+ REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", BlockerResp.StatusCode));
- // Submit a second action — it will stay pending because the slot is occupied
+ // Submit a second action - it will stay pending because the slot is occupied
HttpClient::Response SubmitResp = Client.Post(BlockerUrl, BuildRot13ActionPackage("Retract HTTP Test"sv));
- REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", int(SubmitResp.StatusCode)));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", SubmitResp.StatusCode));
const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission");
- // Wait for the scheduler to process the pending action into m_PendingActions
- Sleep(1'000);
+ // Wait for the blocker action to start running (occupying the single slot)
+ WaitForAnyActionRunningHttp(Client);
// Retract the pending action via POST /jobs/{lsn}/retract
const std::string RetractUrl = fmt::format("/jobs/{}/retract", Lsn);
HttpClient::Response RetractResp = Client.Post(RetractUrl);
CHECK_MESSAGE(RetractResp.StatusCode == HttpResponseCode::OK,
fmt::format("Retract failed: status={}, body={}\nServer log:\n{}",
- int(RetractResp.StatusCode),
+ RetractResp.StatusCode,
RetractResp.ToText(),
Instance.GetLogOutput()));
@@ -2008,10 +2439,45 @@ TEST_CASE("function.retract_http")
}
// A second retract should also succeed (action is back to pending)
- Sleep(500);
+ Sleep(50);
HttpClient::Response RetractResp2 = Client.Post(RetractUrl);
CHECK_MESSAGE(RetractResp2.StatusCode == HttpResponseCode::OK,
- fmt::format("Second retract failed: status={}, body={}", int(RetractResp2.StatusCode), RetractResp2.ToText()));
+ fmt::format("Second retract failed: status={}, body={}", RetractResp2.StatusCode, RetractResp2.ToText()));
+}
+
+TEST_CASE("function.session.immediate_query_after_enqueue")
+{
+ // Verify that actions are immediately visible to GetActionResult and
+ // FindActionResult right after enqueue, without waiting for the
+ // scheduler thread to process the update.
+
+ InMemoryChunkResolver Resolver;
+ ScopedTemporaryDirectory SessionBaseDir;
+ zen::compute::ComputeServiceSession Session(Resolver);
+ Session.Ready();
+
+ CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+ Session.RegisterWorker(WorkerPackage);
+
+ CbObject ActionObj = BuildRot13ActionForSession("immediate-query"sv, Resolver);
+
+ auto EnqueueRes = Session.EnqueueAction(ActionObj, 0);
+ REQUIRE_MESSAGE(EnqueueRes, "Failed to enqueue action");
+
+ // Query by LSN immediately - must not return NotFound
+ CbPackage Result;
+ HttpResponseCode Code = Session.GetActionResult(EnqueueRes.Lsn, Result);
+ CHECK_MESSAGE(Code == HttpResponseCode::Accepted,
+ fmt::format("GetActionResult returned {} immediately after enqueue, expected Accepted", Code));
+
+ // Query by ActionId immediately - must not return NotFound
+ const IoHash ActionId = ActionObj.GetHash();
+ CbPackage FindResult;
+ HttpResponseCode FindCode = Session.FindActionResult(ActionId, FindResult);
+ CHECK_MESSAGE(FindCode == HttpResponseCode::Accepted,
+ fmt::format("FindActionResult returned {} immediately after enqueue, expected Accepted", FindCode));
+
+ Session.Shutdown();
}
TEST_SUITE_END();
diff --git a/src/zenserver-test/hub-tests.cpp b/src/zenserver-test/hub-tests.cpp
index b2da552fc..35a840e5d 100644
--- a/src/zenserver-test/hub-tests.cpp
+++ b/src/zenserver-test/hub-tests.cpp
@@ -329,17 +329,36 @@ TEST_CASE("hub.lifecycle.children")
CHECK_EQ(Result.AsText(), "GhijklmNop"sv);
}
- Result = Client.Post("modules/abc/deprovision");
+ // Deprovision all modules at once
+ Result = Client.Post("deprovision");
REQUIRE(Result);
+ CHECK_EQ(Result.StatusCode, HttpResponseCode::Accepted);
+ {
+ CbObject Body = Result.AsObject();
+ CbArrayView AcceptedArr = Body["Accepted"].AsArrayView();
+ CHECK_EQ(AcceptedArr.Num(), 2u);
+ bool FoundAbc = false;
+ bool FoundDef = false;
+ for (CbFieldView F : AcceptedArr)
+ {
+ if (F.AsString() == "abc"sv)
+ {
+ FoundAbc = true;
+ }
+ else if (F.AsString() == "def"sv)
+ {
+ FoundDef = true;
+ }
+ }
+ CHECK(FoundAbc);
+ CHECK(FoundDef);
+ }
REQUIRE(WaitForModuleGone(Client, "abc"));
+ REQUIRE(WaitForModuleGone(Client, "def"));
{
HttpClient ModClient(fmt::format("http://localhost:{}", AbcPort), kFastTimeout);
CHECK(WaitForPortUnreachable(ModClient));
}
-
- Result = Client.Post("modules/def/deprovision");
- REQUIRE(Result);
- REQUIRE(WaitForModuleGone(Client, "def"));
{
HttpClient ModClient(fmt::format("http://localhost:{}", DefPort), kFastTimeout);
CHECK(WaitForPortUnreachable(ModClient));
@@ -349,6 +368,10 @@ TEST_CASE("hub.lifecycle.children")
Result = Client.Get("status");
REQUIRE(Result);
CHECK_EQ(Result.AsObject()["modules"].AsArrayView().Num(), 0u);
+
+ // Deprovision-all with no modules
+ Result = Client.Post("deprovision");
+ CHECK(Result);
}
static bool
@@ -377,7 +400,7 @@ TEST_CASE("hub.consul.kv")
consul::ConsulProcess ConsulProc;
ConsulProc.SpawnConsulAgent();
- consul::ConsulClient Client("http://localhost:8500/");
+ consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"});
Client.SetKeyValue("zen/hub/testkey", "testvalue");
std::string RetrievedValue = Client.GetKeyValue("zen/hub/testkey");
@@ -399,7 +422,7 @@ TEST_CASE("hub.consul.hub.registration")
"--consul-health-interval-seconds=5 --consul-deregister-after-seconds=60");
REQUIRE(PortNumber != 0);
- consul::ConsulClient Client("http://localhost:8500/");
+ consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"});
REQUIRE(WaitForConsulService(Client, "zen-hub-test-instance", true, 5000));
// Verify custom intervals flowed through to the registered check
@@ -480,7 +503,7 @@ TEST_CASE("hub.consul.hub.registration.token")
// Use a plain client -- dev-mode Consul doesn't enforce ACLs, but the
// server has exercised the ConsulTokenEnv -> GetEnvVariable -> ConsulClient path.
- consul::ConsulClient Client("http://localhost:8500/");
+ consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"});
REQUIRE(WaitForConsulService(Client, "zen-hub-test-instance", true, 5000));
@@ -501,7 +524,7 @@ TEST_CASE("hub.consul.provision.registration")
Instance.SpawnServerAndWaitUntilReady("--consul-endpoint=http://localhost:8500/ --instance-id=test-instance");
REQUIRE(PortNumber != 0);
- consul::ConsulClient Client("http://localhost:8500/");
+ consul::ConsulClient Client({.BaseUri = "http://localhost:8500/"});
REQUIRE(WaitForConsulService(Client, "zen-hub-test-instance", true, 5000));
@@ -762,9 +785,10 @@ TEST_CASE("hub.hibernate.errors")
CHECK(!Result);
CHECK_EQ(Result.StatusCode, HttpResponseCode::NotFound);
+ // Obliterate of an unknown module succeeds (cleans up backend data for dehydrated modules)
Result = Client.Delete("modules/unknown");
- CHECK(!Result);
- CHECK_EQ(Result.StatusCode, HttpResponseCode::NotFound);
+ CHECK(Result);
+ CHECK_EQ(Result.StatusCode, HttpResponseCode::Accepted);
// Double-provision: second call while first is in-flight returns 202 Accepted with the same port.
Result = Client.Post("modules/errmod/provision");
diff --git a/src/zenserver-test/logging-tests.cpp b/src/zenserver-test/logging-tests.cpp
index 2e530ff92..cb0926ddc 100644
--- a/src/zenserver-test/logging-tests.cpp
+++ b/src/zenserver-test/logging-tests.cpp
@@ -69,8 +69,8 @@ TEST_CASE("logging.file.default")
// --quiet sets the console sink level to WARN. The formatted "[info] ..."
// entry written by the default logger's console sink must therefore not appear
-// in captured stdout. (The "console" named logger — used by ZEN_CONSOLE_*
-// macros — may still emit plain-text messages without a level marker, so we
+// in captured stdout. (The "console" named logger - used by ZEN_CONSOLE_*
+// macros - may still emit plain-text messages without a level marker, so we
// check for the absence of the FullFormatter "[info]" prefix rather than the
// message text itself.)
TEST_CASE("logging.console.quiet")
@@ -146,28 +146,6 @@ TEST_CASE("logging.file.json")
CHECK_MESSAGE(LogContains(FileLog, "\"message\""), FileLog);
CHECK_MESSAGE(LogContains(FileLog, "\"source\": \"zenserver\""), FileLog);
CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog);
-}
-
-// --log-id <id> is automatically set to the server instance name in test mode.
-// The JSON formatter emits this value as the "id" field, so every entry in a
-// .json log file must carry a non-empty "id".
-TEST_CASE("logging.log_id")
-{
- const std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
- const std::filesystem::path LogFile = TestDir / "test.json";
-
- ZenServerInstance Instance(TestEnv);
- Instance.SetDataDir(TestDir);
-
- const std::string LogArg = fmt::format("--abslog {}", LogFile.string());
- const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg);
- CHECK_MESSAGE(Port != 0, Instance.GetLogOutput());
-
- Instance.Shutdown();
-
- CHECK_MESSAGE(std::filesystem::exists(LogFile), "JSON log file was not created");
- const std::string FileLog = ReadFileToString(LogFile);
- // The JSON formatter writes the log-id as: "id": "<value>",
CHECK_MESSAGE(LogContains(FileLog, "\"id\": \""), FileLog);
}
diff --git a/src/zenserver-test/objectstore-tests.cpp b/src/zenserver-test/objectstore-tests.cpp
index 1f6a7675c..99c92e15f 100644
--- a/src/zenserver-test/objectstore-tests.cpp
+++ b/src/zenserver-test/objectstore-tests.cpp
@@ -19,18 +19,22 @@ using namespace std::literals;
TEST_SUITE_BEGIN("server.objectstore");
-TEST_CASE("objectstore.blobs")
+TEST_CASE("objectstore")
{
- std::string_view Bucket = "bkt"sv;
+ ZenServerInstance Instance(TestEnv);
+
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--objectstore-enabled");
+ CHECK(Port != 0);
- std::vector<IoHash> CompressedBlobsHashes;
- std::vector<uint64_t> BlobsSizes;
- std::vector<uint64_t> CompressedBlobsSizes;
+ // --- objectstore.blobs ---
{
- ZenServerInstance Instance(TestEnv);
+ INFO("objectstore.blobs");
+
+ std::string_view Bucket = "bkt"sv;
- const uint16_t PortNumber = Instance.SpawnServerAndWaitUntilReady(fmt::format("--objectstore-enabled"));
- CHECK(PortNumber != 0);
+ std::vector<IoHash> CompressedBlobsHashes;
+ std::vector<uint64_t> BlobsSizes;
+ std::vector<uint64_t> CompressedBlobsSizes;
HttpClient Client(Instance.GetBaseUri() + "/obj/");
@@ -68,94 +72,238 @@ TEST_CASE("objectstore.blobs")
CHECK_EQ(RawSize, BlobsSizes[I]);
}
}
+
+ // --- objectstore.s3client ---
+ {
+ INFO("objectstore.s3client");
+
+ S3ClientOptions Opts;
+ Opts.BucketName = "s3test";
+ Opts.Region = "us-east-1";
+ Opts.Endpoint = fmt::format("http://localhost:{}/obj/bucket", Port);
+ Opts.PathStyle = true;
+ Opts.Credentials.AccessKeyId = "testkey";
+ Opts.Credentials.SecretAccessKey = "testsecret";
+
+ S3Client Client(Opts);
+
+ // -- PUT + GET roundtrip --
+ std::string_view TestData = "hello from s3client via objectstore"sv;
+ IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData));
+ S3Result PutRes = Client.PutObject("test/hello.txt", std::move(Content));
+ REQUIRE_MESSAGE(PutRes.IsSuccess(), PutRes.Error);
+
+ S3GetObjectResult GetRes = Client.GetObject("test/hello.txt");
+ REQUIRE_MESSAGE(GetRes.IsSuccess(), GetRes.Error);
+ CHECK(GetRes.AsText() == TestData);
+
+ // -- PUT overwrites --
+ IoBuffer Original = IoBufferBuilder::MakeFromMemory(MakeMemoryView("original"sv));
+ IoBuffer Overwrite = IoBufferBuilder::MakeFromMemory(MakeMemoryView("overwritten"sv));
+ REQUIRE(Client.PutObject("overwrite/file.txt", std::move(Original)).IsSuccess());
+ REQUIRE(Client.PutObject("overwrite/file.txt", std::move(Overwrite)).IsSuccess());
+
+ S3GetObjectResult OverwriteGet = Client.GetObject("overwrite/file.txt");
+ REQUIRE(OverwriteGet.IsSuccess());
+ CHECK(OverwriteGet.AsText() == "overwritten"sv);
+
+ // -- GET not found --
+ S3GetObjectResult NotFoundGet = Client.GetObject("nonexistent/file.dat");
+ CHECK_FALSE(NotFoundGet.IsSuccess());
+
+ // -- HEAD found --
+ std::string_view HeadData = "head test data"sv;
+ IoBuffer HeadContent = IoBufferBuilder::MakeFromMemory(MakeMemoryView(HeadData));
+ REQUIRE(Client.PutObject("head/meta.txt", std::move(HeadContent)).IsSuccess());
+
+ S3HeadObjectResult HeadRes = Client.HeadObject("head/meta.txt");
+ REQUIRE_MESSAGE(HeadRes.IsSuccess(), HeadRes.Error);
+ CHECK(HeadRes.Status == HeadObjectResult::Found);
+ CHECK(HeadRes.Info.Size == HeadData.size());
+
+ // -- HEAD not found --
+ S3HeadObjectResult HeadNotFound = Client.HeadObject("nonexistent/file.dat");
+ CHECK(HeadNotFound.IsSuccess());
+ CHECK(HeadNotFound.Status == HeadObjectResult::NotFound);
+
+ // -- LIST objects --
+ for (int i = 0; i < 3; ++i)
+ {
+ std::string Key = fmt::format("listing/item-{}.txt", i);
+ std::string Payload = fmt::format("content-{}", i);
+ IoBuffer Buf = IoBufferBuilder::MakeFromMemory(MakeMemoryView(Payload));
+ REQUIRE(Client.PutObject(Key, std::move(Buf)).IsSuccess());
+ }
+
+ S3ListObjectsResult ListRes = Client.ListObjects("listing/");
+ REQUIRE_MESSAGE(ListRes.IsSuccess(), ListRes.Error);
+ REQUIRE(ListRes.Objects.size() == 3);
+
+ std::vector<std::string> Keys;
+ for (const S3ObjectInfo& Obj : ListRes.Objects)
+ {
+ Keys.push_back(Obj.Key);
+ CHECK(Obj.Size > 0);
+ }
+ std::sort(Keys.begin(), Keys.end());
+ CHECK(Keys[0] == "listing/item-0.txt");
+ CHECK(Keys[1] == "listing/item-1.txt");
+ CHECK(Keys[2] == "listing/item-2.txt");
+
+ // -- LIST empty prefix --
+ S3ListObjectsResult EmptyList = Client.ListObjects("no-such-prefix/");
+ REQUIRE(EmptyList.IsSuccess());
+ CHECK(EmptyList.Objects.empty());
+ }
+
+ // --- objectstore.range-requests ---
+ {
+ INFO("objectstore.range-requests");
+
+ HttpClient Client(Instance.GetBaseUri() + "/obj/");
+
+ IoBuffer Blob = CreateRandomBlob(1024);
+ MemoryView BlobView = Blob.GetView();
+ std::string ObjectPath = "bucket/bkt/range-test/data.bin";
+
+ HttpClient::Response PutResult = Client.Put(ObjectPath, IoBuffer(Blob));
+ REQUIRE(PutResult);
+
+ // Full GET without Range header
+ {
+ HttpClient::Response Result = Client.Get(ObjectPath);
+ CHECK(Result.StatusCode == HttpResponseCode::OK);
+ CHECK_EQ(Result.ResponsePayload.GetSize(), 1024u);
+ CHECK(Result.ResponsePayload.GetView().EqualBytes(BlobView));
+ }
+
+ // Single range: bytes 100-199
+ {
+ HttpClient::Response Result = Client.Get(ObjectPath, {{"Range", "bytes=100-199"}});
+ CHECK(Result.StatusCode == HttpResponseCode::PartialContent);
+ CHECK_EQ(Result.ResponsePayload.GetSize(), 100u);
+ CHECK(Result.ResponsePayload.GetView().EqualBytes(BlobView.Mid(100, 100)));
+ }
+
+ // Range starting at zero: bytes 0-49
+ {
+ HttpClient::Response Result = Client.Get(ObjectPath, {{"Range", "bytes=0-49"}});
+ CHECK(Result.StatusCode == HttpResponseCode::PartialContent);
+ CHECK_EQ(Result.ResponsePayload.GetSize(), 50u);
+ CHECK(Result.ResponsePayload.GetView().EqualBytes(BlobView.Mid(0, 50)));
+ }
+
+ // Range at end of file: bytes 1000-1023
+ {
+ HttpClient::Response Result = Client.Get(ObjectPath, {{"Range", "bytes=1000-1023"}});
+ CHECK(Result.StatusCode == HttpResponseCode::PartialContent);
+ CHECK_EQ(Result.ResponsePayload.GetSize(), 24u);
+ CHECK(Result.ResponsePayload.GetView().EqualBytes(BlobView.Mid(1000, 24)));
+ }
+
+ // Multiple ranges: bytes 0-49 and 100-149
+ {
+ HttpClient::Response Result = Client.Get(ObjectPath, {{"Range", "bytes=0-49,100-149"}});
+ CHECK(Result.StatusCode == HttpResponseCode::PartialContent);
+
+ std::string_view Body(reinterpret_cast<const char*>(Result.ResponsePayload.GetData()), Result.ResponsePayload.GetSize());
+
+ // Verify multipart structure contains both range payloads
+ CHECK(Body.find("Content-Range: bytes 0-49/1024") != std::string_view::npos);
+ CHECK(Body.find("Content-Range: bytes 100-149/1024") != std::string_view::npos);
+
+ // Extract and verify actual data for first range
+ auto FindPartData = [&](std::string_view ContentRange) -> std::string_view {
+ size_t Pos = Body.find(ContentRange);
+ if (Pos == std::string_view::npos)
+ {
+ return {};
+ }
+ // Skip past the Content-Range line and the blank line separator
+ Pos = Body.find("\r\n\r\n", Pos);
+ if (Pos == std::string_view::npos)
+ {
+ return {};
+ }
+ Pos += 4;
+ size_t End = Body.find("\r\n--", Pos);
+ if (End == std::string_view::npos)
+ {
+ return {};
+ }
+ return Body.substr(Pos, End - Pos);
+ };
+
+ std::string_view Part1 = FindPartData("Content-Range: bytes 0-49/1024");
+ CHECK_EQ(Part1.size(), 50u);
+ CHECK(MemoryView(Part1.data(), Part1.size()).EqualBytes(BlobView.Mid(0, 50)));
+
+ std::string_view Part2 = FindPartData("Content-Range: bytes 100-149/1024");
+ CHECK_EQ(Part2.size(), 50u);
+ CHECK(MemoryView(Part2.data(), Part2.size()).EqualBytes(BlobView.Mid(100, 50)));
+ }
+
+ // Out-of-bounds single range
+ {
+ HttpClient::Response Result = Client.Get(ObjectPath, {{"Range", "bytes=2000-2099"}});
+ CHECK(Result.StatusCode == HttpResponseCode::RangeNotSatisfiable);
+ }
+
+ // Out-of-bounds multi-range
+ {
+ HttpClient::Response Result = Client.Get(ObjectPath, {{"Range", "bytes=0-49,2000-2099"}});
+ CHECK(Result.StatusCode == HttpResponseCode::RangeNotSatisfiable);
+ }
+ }
}
-TEST_CASE("objectstore.s3client")
+TEST_CASE("objectstore.range-requests-download")
{
ZenServerInstance Instance(TestEnv);
const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--objectstore-enabled");
- CHECK_MESSAGE(Port != 0, Instance.GetLogOutput());
-
- // S3Client in path-style builds paths as /{bucket}/{key}.
- // The objectstore routes objects at bucket/{bucket}/{key} relative to its base.
- // Point the S3Client endpoint at {server}/obj/bucket so the paths line up.
- S3ClientOptions Opts;
- Opts.BucketName = "s3test";
- Opts.Region = "us-east-1";
- Opts.Endpoint = fmt::format("http://localhost:{}/obj/bucket", Port);
- Opts.PathStyle = true;
- Opts.Credentials.AccessKeyId = "testkey";
- Opts.Credentials.SecretAccessKey = "testsecret";
-
- S3Client Client(Opts);
-
- // -- PUT + GET roundtrip --
- std::string_view TestData = "hello from s3client via objectstore"sv;
- IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData));
- S3Result PutRes = Client.PutObject("test/hello.txt", std::move(Content));
- REQUIRE_MESSAGE(PutRes.IsSuccess(), PutRes.Error);
-
- S3GetObjectResult GetRes = Client.GetObject("test/hello.txt");
- REQUIRE_MESSAGE(GetRes.IsSuccess(), GetRes.Error);
- CHECK(GetRes.AsText() == TestData);
-
- // -- PUT overwrites --
- IoBuffer Original = IoBufferBuilder::MakeFromMemory(MakeMemoryView("original"sv));
- IoBuffer Overwrite = IoBufferBuilder::MakeFromMemory(MakeMemoryView("overwritten"sv));
- REQUIRE(Client.PutObject("overwrite/file.txt", std::move(Original)).IsSuccess());
- REQUIRE(Client.PutObject("overwrite/file.txt", std::move(Overwrite)).IsSuccess());
-
- S3GetObjectResult OverwriteGet = Client.GetObject("overwrite/file.txt");
- REQUIRE(OverwriteGet.IsSuccess());
- CHECK(OverwriteGet.AsText() == "overwritten"sv);
-
- // -- GET not found --
- S3GetObjectResult NotFoundGet = Client.GetObject("nonexistent/file.dat");
- CHECK_FALSE(NotFoundGet.IsSuccess());
-
- // -- HEAD found --
- std::string_view HeadData = "head test data"sv;
- IoBuffer HeadContent = IoBufferBuilder::MakeFromMemory(MakeMemoryView(HeadData));
- REQUIRE(Client.PutObject("head/meta.txt", std::move(HeadContent)).IsSuccess());
-
- S3HeadObjectResult HeadRes = Client.HeadObject("head/meta.txt");
- REQUIRE_MESSAGE(HeadRes.IsSuccess(), HeadRes.Error);
- CHECK(HeadRes.Status == HeadObjectResult::Found);
- CHECK(HeadRes.Info.Size == HeadData.size());
-
- // -- HEAD not found --
- S3HeadObjectResult HeadNotFound = Client.HeadObject("nonexistent/file.dat");
- CHECK(HeadNotFound.IsSuccess());
- CHECK(HeadNotFound.Status == HeadObjectResult::NotFound);
-
- // -- LIST objects --
- for (int i = 0; i < 3; ++i)
+ REQUIRE(Port != 0);
+
+ HttpClient Client(Instance.GetBaseUri() + "/obj/");
+
+ IoBuffer Blob = CreateRandomBlob(1024);
+ MemoryView BlobView = Blob.GetView();
+ std::string ObjectPath = "bucket/bkt/range-download-test/data.bin";
+
+ HttpClient::Response PutResult = Client.Put(ObjectPath, IoBuffer(Blob));
+ REQUIRE(PutResult);
+
+ ScopedTemporaryDirectory DownloadDir;
+
+ // Single range via Download: verify Ranges is populated and GetRanges maps correctly
{
- std::string Key = fmt::format("listing/item-{}.txt", i);
- std::string Payload = fmt::format("content-{}", i);
- IoBuffer Buf = IoBufferBuilder::MakeFromMemory(MakeMemoryView(Payload));
- REQUIRE(Client.PutObject(Key, std::move(Buf)).IsSuccess());
- }
+ HttpClient::Response Result = Client.Download(ObjectPath, DownloadDir.Path(), {{"Range", "bytes=100-199"}});
+ CHECK(Result.StatusCode == HttpResponseCode::PartialContent);
+ REQUIRE_EQ(Result.Ranges.size(), 1u);
+ CHECK_EQ(Result.Ranges[0].RangeOffset, 100u);
+ CHECK_EQ(Result.Ranges[0].RangeLength, 100u);
- S3ListObjectsResult ListRes = Client.ListObjects("listing/");
- REQUIRE_MESSAGE(ListRes.IsSuccess(), ListRes.Error);
- REQUIRE(ListRes.Objects.size() == 3);
+ std::vector<std::pair<uint64_t, uint64_t>> RequestedRanges = {{100, 100}};
+ std::vector<std::pair<uint64_t, uint64_t>> PayloadRanges = Result.GetRanges(RequestedRanges);
+ REQUIRE_EQ(PayloadRanges.size(), 1u);
+ CHECK(Result.ResponsePayload.GetView().Mid(PayloadRanges[0].first, PayloadRanges[0].second).EqualBytes(BlobView.Mid(100, 100)));
+ }
- std::vector<std::string> Keys;
- for (const S3ObjectInfo& Obj : ListRes.Objects)
+ // Multi-range via Download: verify Ranges is populated for both parts and GetRanges maps correctly
{
- Keys.push_back(Obj.Key);
- CHECK(Obj.Size > 0);
+ HttpClient::Response Result = Client.Download(ObjectPath, DownloadDir.Path(), {{"Range", "bytes=0-49,100-149"}});
+ CHECK(Result.StatusCode == HttpResponseCode::PartialContent);
+ REQUIRE_EQ(Result.Ranges.size(), 2u);
+ CHECK_EQ(Result.Ranges[0].RangeOffset, 0u);
+ CHECK_EQ(Result.Ranges[0].RangeLength, 50u);
+ CHECK_EQ(Result.Ranges[1].RangeOffset, 100u);
+ CHECK_EQ(Result.Ranges[1].RangeLength, 50u);
+
+ std::vector<std::pair<uint64_t, uint64_t>> RequestedRanges = {{0, 50}, {100, 50}};
+ std::vector<std::pair<uint64_t, uint64_t>> PayloadRanges = Result.GetRanges(RequestedRanges);
+ REQUIRE_EQ(PayloadRanges.size(), 2u);
+ CHECK(Result.ResponsePayload.GetView().Mid(PayloadRanges[0].first, PayloadRanges[0].second).EqualBytes(BlobView.Mid(0, 50)));
+ CHECK(Result.ResponsePayload.GetView().Mid(PayloadRanges[1].first, PayloadRanges[1].second).EqualBytes(BlobView.Mid(100, 50)));
}
- std::sort(Keys.begin(), Keys.end());
- CHECK(Keys[0] == "listing/item-0.txt");
- CHECK(Keys[1] == "listing/item-1.txt");
- CHECK(Keys[2] == "listing/item-2.txt");
-
- // -- LIST empty prefix --
- S3ListObjectsResult EmptyList = Client.ListObjects("no-such-prefix/");
- REQUIRE(EmptyList.IsSuccess());
- CHECK(EmptyList.Objects.empty());
}
TEST_SUITE_END();
diff --git a/src/zenserver-test/process-tests.cpp b/src/zenserver-test/process-tests.cpp
index ae11bb294..3f6476810 100644
--- a/src/zenserver-test/process-tests.cpp
+++ b/src/zenserver-test/process-tests.cpp
@@ -115,7 +115,7 @@ TEST_CASE("pipe.raii_cleanup")
{
StdoutPipeHandles Pipe;
REQUIRE(CreateStdoutPipe(Pipe));
- // Pipe goes out of scope here — destructor should close both ends
+ // Pipe goes out of scope here - destructor should close both ends
}
}
@@ -155,7 +155,7 @@ TEST_CASE("pipe.move_semantics")
CHECK(Moved.WriteFd == -1);
# endif
- // Assigned goes out of scope — destructor closes handles
+ // Assigned goes out of scope - destructor closes handles
}
TEST_CASE("pipe.close_is_idempotent")
diff --git a/src/zenserver-test/projectstore-tests.cpp b/src/zenserver-test/projectstore-tests.cpp
index a37ecb6be..49d985abb 100644
--- a/src/zenserver-test/projectstore-tests.cpp
+++ b/src/zenserver-test/projectstore-tests.cpp
@@ -22,6 +22,7 @@ ZEN_THIRD_PARTY_INCLUDES_START
ZEN_THIRD_PARTY_INCLUDES_END
# include <random>
+# include <thread>
namespace zen::tests {
@@ -40,326 +41,429 @@ TEST_CASE("project.basic")
const uint16_t PortNumber = Instance1.SpawnServerAndWaitUntilReady();
- std::mt19937_64 mt;
-
- zen::StringBuilder<64> BaseUri;
- BaseUri << fmt::format("http://localhost:{}", PortNumber);
+ std::string ServerUri = fmt::format("http://localhost:{}", PortNumber);
std::filesystem::path BinPath = zen::GetRunningExecutablePath();
std::filesystem::path RootPath = BinPath.parent_path().parent_path();
BinPath = BinPath.lexically_relative(RootPath);
- SUBCASE("build store init")
+ auto CreateProjectAndOplog = [&](std::string_view ProjectName, std::string_view OplogName) -> std::string {
+ HttpClient Http{ServerUri};
+
+ zen::CbObjectWriter Body;
+ Body << "id" << ProjectName;
+ Body << "root" << RootPath.c_str();
+ Body << "project"
+ << "/zooom";
+ Body << "engine"
+ << "/zooom";
+ IoBuffer BodyBuf = Body.Save().GetBuffer().AsIoBuffer();
+ auto Response = Http.Post(fmt::format("/prj/{}", ProjectName), BodyBuf);
+ REQUIRE(Response.StatusCode == HttpResponseCode::Created);
+
+ std::string OplogUri = fmt::format("{}/prj/{}/oplog/{}", ServerUri, ProjectName, OplogName);
+ HttpClient OplogHttp{OplogUri};
+ auto OplogResponse = OplogHttp.Post(""sv, IoBuffer{}, ZenContentType::kCbObject);
+ REQUIRE(OplogResponse.StatusCode == HttpResponseCode::Created);
+
+ return OplogUri;
+ };
+
+ // Create a file at a path exceeding Windows MAX_PATH (260 chars) for long filename testing
+ std::filesystem::path LongPathDir = RootPath / "longpathtest";
+ for (int I = 0; I < 5; ++I)
{
- {
- HttpClient Http{BaseUri};
+ LongPathDir /= std::string(50, char('a' + I));
+ }
+ std::filesystem::path LongFilePath = LongPathDir / "testfile.bin";
+ std::filesystem::path LongRelPath = LongFilePath.lexically_relative(RootPath);
- {
- zen::CbObjectWriter Body;
- Body << "id"
- << "test";
- Body << "root" << RootPath.c_str();
- Body << "project"
- << "/zooom";
- Body << "engine"
- << "/zooom";
-
- zen::BinaryWriter MemOut;
- IoBuffer BodyBuf = Body.Save().GetBuffer().AsIoBuffer();
-
- auto Response = Http.Post("/prj/test"sv, BodyBuf);
- CHECK(Response.StatusCode == HttpResponseCode::Created);
- }
+ const uint8_t LongPathFileData[] = {0xDE, 0xAD, 0xBE, 0xEF};
+ CreateDirectories(MakeSafeAbsolutePath(LongPathDir));
+ WriteFile(MakeSafeAbsolutePath(LongFilePath), IoBufferBuilder::MakeCloneFromMemory(LongPathFileData, sizeof(LongPathFileData)));
+ CHECK(LongRelPath.string().length() > 260);
- {
- auto Response = Http.Get("/prj/test"sv);
- REQUIRE(Response.StatusCode == HttpResponseCode::OK);
+ std::string LongClientPath = "/{engine}/client";
+ for (int I = 0; I < 5; ++I)
+ {
+ LongClientPath += '/';
+ LongClientPath.append(50, char('a' + I));
+ }
+ LongClientPath += "/longfile.bin";
+ CHECK(LongClientPath.length() > 260);
- CbObject ResponseObject = Response.AsObject();
+ const std::string_view LongPathChunkId{
+ "00000000"
+ "00000000"
+ "00020000"};
+ auto LongPathFileOid = zen::Oid::FromHexString(LongPathChunkId);
- CHECK(ResponseObject["id"].AsString() == "test"sv);
- CHECK(ResponseObject["root"].AsString() == PathToUtf8(RootPath.c_str()));
- }
+ // --- build store persistence ---
+ // First section also verifies project and oplog creation responses.
+ {
+ HttpClient ServerHttp{ServerUri};
+
+ {
+ zen::CbObjectWriter Body;
+ Body << "id"
+ << "test_persist";
+ Body << "root" << RootPath.c_str();
+ Body << "project"
+ << "/zooom";
+ Body << "engine"
+ << "/zooom";
+ IoBuffer BodyBuf = Body.Save().GetBuffer().AsIoBuffer();
+
+ auto Response = ServerHttp.Post("/prj/test_persist"sv, BodyBuf);
+ CHECK(Response.StatusCode == HttpResponseCode::Created);
+ }
+
+ {
+ auto Response = ServerHttp.Get("/prj/test_persist"sv);
+ REQUIRE(Response.StatusCode == HttpResponseCode::OK);
+
+ CbObject ResponseObject = Response.AsObject();
+
+ CHECK(ResponseObject["id"].AsString() == "test_persist"sv);
+ CHECK(ResponseObject["root"].AsString() == PathToUtf8(RootPath.c_str()));
}
- BaseUri << "/prj/test/oplog/foobar";
+ std::string OplogUri = fmt::format("{}/prj/test_persist/oplog/oplog_persist", ServerUri);
{
- HttpClient Http{BaseUri};
+ HttpClient OplogHttp{OplogUri};
{
- auto Response = Http.Post(""sv, IoBuffer{}, ZenContentType::kCbObject);
+ auto Response = OplogHttp.Post(""sv, IoBuffer{}, ZenContentType::kCbObject);
CHECK(Response.StatusCode == HttpResponseCode::Created);
}
{
- auto Response = Http.Get(""sv);
+ auto Response = OplogHttp.Get(""sv);
REQUIRE(Response.StatusCode == HttpResponseCode::OK);
CbObject ResponseObject = Response.AsObject();
- CHECK(ResponseObject["id"].AsString() == "foobar"sv);
- CHECK(ResponseObject["project"].AsString() == "test"sv);
+ CHECK(ResponseObject["id"].AsString() == "oplog_persist"sv);
+ CHECK(ResponseObject["project"].AsString() == "test_persist"sv);
}
}
- // Create a file at a path exceeding Windows MAX_PATH (260 chars) for long filename testing
- std::filesystem::path LongPathDir = RootPath / "longpathtest";
- for (int I = 0; I < 5; ++I)
+ uint8_t AttachData[] = {1, 2, 3};
+
+ zen::CompressedBuffer Attachment = zen::CompressedBuffer::Compress(zen::SharedBuffer::Clone(zen::MemoryView{AttachData, 3}));
+ zen::CbAttachment Attach{Attachment, Attachment.DecodeRawHash()};
+
+ zen::CbObjectWriter OpWriter;
+ OpWriter << "key"
+ << "foo"
+ << "attachment" << Attach;
+
+ const std::string_view ChunkId{
+ "00000000"
+ "00000000"
+ "00010000"};
+ auto FileOid = zen::Oid::FromHexString(ChunkId);
+
+ OpWriter.BeginArray("files");
+ OpWriter.BeginObject();
+ OpWriter << "id" << FileOid;
+ OpWriter << "clientpath"
+ << "/{engine}/client/side/path";
+ OpWriter << "serverpath" << BinPath.c_str();
+ OpWriter.EndObject();
+ OpWriter.BeginObject();
+ OpWriter << "id" << LongPathFileOid;
+ OpWriter << "clientpath" << LongClientPath;
+ OpWriter << "serverpath" << LongRelPath.c_str();
+ OpWriter.EndObject();
+ OpWriter.EndArray();
+
+ zen::CbObject Op = OpWriter.Save();
+
+ zen::CbPackage OpPackage(Op);
+ OpPackage.AddAttachment(Attach);
+
+ zen::BinaryWriter MemOut;
+ legacy::SaveCbPackage(OpPackage, MemOut);
+
+ HttpClient Http{OplogUri};
+
{
- LongPathDir /= std::string(50, char('a' + I));
- }
- std::filesystem::path LongFilePath = LongPathDir / "testfile.bin";
- std::filesystem::path LongRelPath = LongFilePath.lexically_relative(RootPath);
+ auto Response = Http.Post("/new", IoBufferBuilder::MakeFromMemory(MemOut.GetView()));
- const uint8_t LongPathFileData[] = {0xDE, 0xAD, 0xBE, 0xEF};
- CreateDirectories(MakeSafeAbsolutePath(LongPathDir));
- WriteFile(MakeSafeAbsolutePath(LongFilePath), IoBufferBuilder::MakeCloneFromMemory(LongPathFileData, sizeof(LongPathFileData)));
- CHECK(LongRelPath.string().length() > 260);
+ REQUIRE(Response);
+ CHECK(Response.StatusCode == HttpResponseCode::Created);
+ }
- std::string LongClientPath = "/{engine}/client";
- for (int I = 0; I < 5; ++I)
{
- LongClientPath += '/';
- LongClientPath.append(50, char('a' + I));
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << "/" << ChunkId;
+ auto Response = Http.Get(ChunkGetUri);
+
+ REQUIRE(Response);
+ CHECK(Response.StatusCode == HttpResponseCode::OK);
}
- LongClientPath += "/longfile.bin";
- CHECK(LongClientPath.length() > 260);
- const std::string_view LongPathChunkId{
- "00000000"
- "00000000"
- "00020000"};
- auto LongPathFileOid = zen::Oid::FromHexString(LongPathChunkId);
+ {
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << "/" << ChunkId << "?offset=1&size=10";
+ auto Response = Http.Get(ChunkGetUri);
+
+ REQUIRE(Response);
+ CHECK(Response.StatusCode == HttpResponseCode::OK);
+ CHECK(Response.ResponsePayload.GetSize() == 10);
+ }
- SUBCASE("build store persistence")
{
- uint8_t AttachData[] = {1, 2, 3};
-
- zen::CompressedBuffer Attachment = zen::CompressedBuffer::Compress(zen::SharedBuffer::Clone(zen::MemoryView{AttachData, 3}));
- zen::CbAttachment Attach{Attachment, Attachment.DecodeRawHash()};
-
- zen::CbObjectWriter OpWriter;
- OpWriter << "key"
- << "foo"
- << "attachment" << Attach;
-
- const std::string_view ChunkId{
- "00000000"
- "00000000"
- "00010000"};
- auto FileOid = zen::Oid::FromHexString(ChunkId);
-
- OpWriter.BeginArray("files");
- OpWriter.BeginObject();
- OpWriter << "id" << FileOid;
- OpWriter << "clientpath"
- << "/{engine}/client/side/path";
- OpWriter << "serverpath" << BinPath.c_str();
- OpWriter.EndObject();
- OpWriter.BeginObject();
- OpWriter << "id" << LongPathFileOid;
- OpWriter << "clientpath" << LongClientPath;
- OpWriter << "serverpath" << LongRelPath.c_str();
- OpWriter.EndObject();
- OpWriter.EndArray();
-
- zen::CbObject Op = OpWriter.Save();
-
- zen::CbPackage OpPackage(Op);
- OpPackage.AddAttachment(Attach);
-
- zen::BinaryWriter MemOut;
- legacy::SaveCbPackage(OpPackage, MemOut);
-
- HttpClient Http{BaseUri};
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << "/" << LongPathChunkId;
+ auto Response = Http.Get(ChunkGetUri);
- {
- auto Response = Http.Post("/new", IoBufferBuilder::MakeFromMemory(MemOut.GetView()));
+ REQUIRE(Response);
+ CHECK(Response.StatusCode == HttpResponseCode::OK);
+ CHECK(Response.ResponsePayload.GetSize() == sizeof(LongPathFileData));
+ }
- REQUIRE(Response);
- CHECK(Response.StatusCode == HttpResponseCode::Created);
- }
+ ZEN_INFO("+++++++");
+ }
- // Read file data
+ // --- snapshot ---
+ {
+ std::string OplogUri = CreateProjectAndOplog("test_snap", "oplog_snap");
- {
- zen::StringBuilder<128> ChunkGetUri;
- ChunkGetUri << "/" << ChunkId;
- auto Response = Http.Get(ChunkGetUri);
+ zen::CbObjectWriter OpWriter;
+ OpWriter << "key"
+ << "foo";
- REQUIRE(Response);
- CHECK(Response.StatusCode == HttpResponseCode::OK);
- }
+ const std::string_view ChunkId{
+ "00000000"
+ "00000000"
+ "00010000"};
+ auto FileOid = zen::Oid::FromHexString(ChunkId);
+
+ OpWriter.BeginArray("files");
+ OpWriter.BeginObject();
+ OpWriter << "id" << FileOid;
+ OpWriter << "clientpath"
+ << "/{engine}/client/side/path";
+ OpWriter << "serverpath" << BinPath.c_str();
+ OpWriter.EndObject();
+ OpWriter.BeginObject();
+ OpWriter << "id" << LongPathFileOid;
+ OpWriter << "clientpath" << LongClientPath;
+ OpWriter << "serverpath" << LongRelPath.c_str();
+ OpWriter.EndObject();
+ OpWriter.EndArray();
+
+ zen::CbObject Op = OpWriter.Save();
+
+ zen::CbPackage OpPackage(Op);
- {
- zen::StringBuilder<128> ChunkGetUri;
- ChunkGetUri << "/" << ChunkId << "?offset=1&size=10";
- auto Response = Http.Get(ChunkGetUri);
+ zen::BinaryWriter MemOut;
+ legacy::SaveCbPackage(OpPackage, MemOut);
- REQUIRE(Response);
- CHECK(Response.StatusCode == HttpResponseCode::OK);
- CHECK(Response.ResponsePayload.GetSize() == 10);
- }
+ HttpClient Http{OplogUri};
- // Read long-path file data
- {
- zen::StringBuilder<128> ChunkGetUri;
- ChunkGetUri << "/" << LongPathChunkId;
- auto Response = Http.Get(ChunkGetUri);
+ {
+ auto Response = Http.Post("/new", IoBufferBuilder::MakeFromMemory(MemOut.GetView()));
- REQUIRE(Response);
- CHECK(Response.StatusCode == HttpResponseCode::OK);
- CHECK(Response.ResponsePayload.GetSize() == sizeof(LongPathFileData));
- }
+ REQUIRE(Response);
+ CHECK(Response.StatusCode == HttpResponseCode::Created);
+ }
+
+ // Read file data, it is raw and uncompressed
+ {
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << "/" << ChunkId;
+ auto Response = Http.Get(ChunkGetUri);
- ZEN_INFO("+++++++");
+ REQUIRE(Response);
+ REQUIRE(Response.StatusCode == HttpResponseCode::OK);
+
+ IoBuffer Data = Response.ResponsePayload;
+ IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath);
+ CHECK(ReferenceData.GetSize() == Data.GetSize());
+ CHECK(ReferenceData.GetView().EqualBytes(Data.GetView()));
}
- SUBCASE("snapshot")
+ // Read long-path file data, it is raw and uncompressed
{
- zen::CbObjectWriter OpWriter;
- OpWriter << "key"
- << "foo";
-
- const std::string_view ChunkId{
- "00000000"
- "00000000"
- "00010000"};
- auto FileOid = zen::Oid::FromHexString(ChunkId);
-
- OpWriter.BeginArray("files");
- OpWriter.BeginObject();
- OpWriter << "id" << FileOid;
- OpWriter << "clientpath"
- << "/{engine}/client/side/path";
- OpWriter << "serverpath" << BinPath.c_str();
- OpWriter.EndObject();
- OpWriter.BeginObject();
- OpWriter << "id" << LongPathFileOid;
- OpWriter << "clientpath" << LongClientPath;
- OpWriter << "serverpath" << LongRelPath.c_str();
- OpWriter.EndObject();
- OpWriter.EndArray();
-
- zen::CbObject Op = OpWriter.Save();
-
- zen::CbPackage OpPackage(Op);
-
- zen::BinaryWriter MemOut;
- legacy::SaveCbPackage(OpPackage, MemOut);
-
- HttpClient Http{BaseUri};
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << "/" << LongPathChunkId;
+ auto Response = Http.Get(ChunkGetUri);
- {
- auto Response = Http.Post("/new", IoBufferBuilder::MakeFromMemory(MemOut.GetView()));
+ REQUIRE(Response);
+ REQUIRE(Response.StatusCode == HttpResponseCode::OK);
- REQUIRE(Response);
- CHECK(Response.StatusCode == HttpResponseCode::Created);
- }
+ IoBuffer Data = Response.ResponsePayload;
+ MemoryView ExpectedView{LongPathFileData, sizeof(LongPathFileData)};
+ CHECK(Data.GetSize() == sizeof(LongPathFileData));
+ CHECK(Data.GetView().EqualBytes(ExpectedView));
+ }
- // Read file data, it is raw and uncompressed
- {
- zen::StringBuilder<128> ChunkGetUri;
- ChunkGetUri << "/" << ChunkId;
- auto Response = Http.Get(ChunkGetUri);
+ {
+ IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer.AddString("method"sv, "snapshot"sv); });
+ auto Response = Http.Post("/rpc"sv, Payload);
+ REQUIRE(Response);
+ CHECK(Response.StatusCode == HttpResponseCode::OK);
+ }
- REQUIRE(Response);
- REQUIRE(Response.StatusCode == HttpResponseCode::OK);
+ // Read chunk data, it is now compressed
+ {
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << "/" << ChunkId;
+ auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}});
- IoBuffer Data = Response.ResponsePayload;
- IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath);
- CHECK(ReferenceData.GetSize() == Data.GetSize());
- CHECK(ReferenceData.GetView().EqualBytes(Data.GetView()));
- }
+ REQUIRE(Response);
+ REQUIRE(Response.StatusCode == HttpResponseCode::OK);
+
+ IoBuffer Data = Response.ResponsePayload;
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize);
+ REQUIRE(Compressed);
+ IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer();
+ IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath);
+ CHECK(RawSize == ReferenceData.GetSize());
+ CHECK(ReferenceData.GetSize() == DataDecompressed.GetSize());
+ CHECK(ReferenceData.GetView().EqualBytes(DataDecompressed.GetView()));
+ }
- // Read long-path file data, it is raw and uncompressed
- {
- zen::StringBuilder<128> ChunkGetUri;
- ChunkGetUri << "/" << LongPathChunkId;
- auto Response = Http.Get(ChunkGetUri);
+ // Read compressed long-path file data after snapshot
+ {
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << "/" << LongPathChunkId;
+ auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}});
- REQUIRE(Response);
- REQUIRE(Response.StatusCode == HttpResponseCode::OK);
+ REQUIRE(Response);
+ REQUIRE(Response.StatusCode == HttpResponseCode::OK);
+
+ IoBuffer Data = Response.ResponsePayload;
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize);
+ REQUIRE(Compressed);
+ IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer();
+ MemoryView ExpectedView{LongPathFileData, sizeof(LongPathFileData)};
+ CHECK(RawSize == sizeof(LongPathFileData));
+ CHECK(DataDecompressed.GetSize() == sizeof(LongPathFileData));
+ CHECK(DataDecompressed.GetView().EqualBytes(ExpectedView));
+ }
- IoBuffer Data = Response.ResponsePayload;
- MemoryView ExpectedView{LongPathFileData, sizeof(LongPathFileData)};
- CHECK(Data.GetSize() == sizeof(LongPathFileData));
- CHECK(Data.GetView().EqualBytes(ExpectedView));
- }
+ ZEN_INFO("+++++++");
+ }
- {
- IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer.AddString("method"sv, "snapshot"sv); });
- auto Response = Http.Post("/rpc"sv, Payload);
- REQUIRE(Response);
- CHECK(Response.StatusCode == HttpResponseCode::OK);
- }
+ // --- snapshot zero byte file ---
+ {
+ std::string OplogUri = CreateProjectAndOplog("test_zero", "oplog_zero");
- // Read chunk data, it is now compressed
- {
- zen::StringBuilder<128> ChunkGetUri;
- ChunkGetUri << "/" << ChunkId;
- auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}});
+ std::filesystem::path EmptyFileRelPath = std::filesystem::path("zerobyte_snapshot_test") / "empty.bin";
+ std::filesystem::path EmptyFileAbsPath = RootPath / EmptyFileRelPath;
+ CreateDirectories(MakeSafeAbsolutePath(EmptyFileAbsPath.parent_path()));
+ WriteFile(MakeSafeAbsolutePath(EmptyFileAbsPath), IoBuffer{});
+ REQUIRE(IsFile(MakeSafeAbsolutePath(EmptyFileAbsPath)));
- REQUIRE(Response);
- REQUIRE(Response.StatusCode == HttpResponseCode::OK);
+ const std::string_view EmptyChunkId{
+ "00000000"
+ "00000000"
+ "00030000"};
+ auto EmptyFileOid = zen::Oid::FromHexString(EmptyChunkId);
+
+ zen::CbObjectWriter OpWriter;
+ OpWriter << "key"
+ << "zero_byte_test";
+ OpWriter.BeginArray("files");
+ OpWriter.BeginObject();
+ OpWriter << "id" << EmptyFileOid;
+ OpWriter << "clientpath"
+ << "/{engine}/empty_file";
+ OpWriter << "serverpath" << EmptyFileRelPath.c_str();
+ OpWriter.EndObject();
+ OpWriter.EndArray();
+
+ zen::CbObject Op = OpWriter.Save();
+ zen::CbPackage OpPackage(Op);
- IoBuffer Data = Response.ResponsePayload;
- IoHash RawHash;
- uint64_t RawSize;
- CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize);
- REQUIRE(Compressed);
- IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer();
- IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath);
- CHECK(RawSize == ReferenceData.GetSize());
- CHECK(ReferenceData.GetSize() == DataDecompressed.GetSize());
- CHECK(ReferenceData.GetView().EqualBytes(DataDecompressed.GetView()));
- }
+ zen::BinaryWriter MemOut;
+ legacy::SaveCbPackage(OpPackage, MemOut);
- // Read compressed long-path file data after snapshot
- {
- zen::StringBuilder<128> ChunkGetUri;
- ChunkGetUri << "/" << LongPathChunkId;
- auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}});
+ HttpClient Http{OplogUri};
- REQUIRE(Response);
- REQUIRE(Response.StatusCode == HttpResponseCode::OK);
+ {
+ auto Response = Http.Post("/new", IoBufferBuilder::MakeFromMemory(MemOut.GetView()));
+ REQUIRE(Response);
+ CHECK(Response.StatusCode == HttpResponseCode::Created);
+ }
- IoBuffer Data = Response.ResponsePayload;
- IoHash RawHash;
- uint64_t RawSize;
- CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize);
- REQUIRE(Compressed);
- IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer();
- MemoryView ExpectedView{LongPathFileData, sizeof(LongPathFileData)};
- CHECK(RawSize == sizeof(LongPathFileData));
- CHECK(DataDecompressed.GetSize() == sizeof(LongPathFileData));
- CHECK(DataDecompressed.GetView().EqualBytes(ExpectedView));
- }
+ // Read file data before snapshot - raw and uncompressed, 0 bytes.
+ // http.sys converts a 200 OK with empty body to 204 No Content, so
+ // accept either status code.
+ {
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << "/" << EmptyChunkId;
+ auto Response = Http.Get(ChunkGetUri);
- ZEN_INFO("+++++++");
+ REQUIRE(Response);
+ CHECK((Response.StatusCode == HttpResponseCode::OK || Response.StatusCode == HttpResponseCode::NoContent));
+ CHECK(Response.ResponsePayload.GetSize() == 0);
}
- SUBCASE("test chunk not found error")
+ // Trigger snapshot.
{
- HttpClient Http{BaseUri};
+ IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { Writer.AddString("method"sv, "snapshot"sv); });
+ auto Response = Http.Post("/rpc"sv, Payload);
+ REQUIRE(Response);
+ CHECK(Response.StatusCode == HttpResponseCode::OK);
+ }
- for (size_t I = 0; I < 65; I++)
- {
- zen::StringBuilder<128> PostUri;
- PostUri << "/f77c781846caead318084604/info";
- auto Response = Http.Get(PostUri);
+ // Read chunk after snapshot - compressed, decompresses to 0 bytes.
+ {
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << "/" << EmptyChunkId;
+ auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}});
- REQUIRE(!Response.Error);
- CHECK(Response.StatusCode == HttpResponseCode::NotFound);
- }
+ REQUIRE(Response);
+ REQUIRE(Response.StatusCode == HttpResponseCode::OK);
+
+ IoBuffer Data = Response.ResponsePayload;
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize);
+ REQUIRE(Compressed);
+ CHECK(RawSize == 0);
+ IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer();
+ CHECK(DataDecompressed.GetSize() == 0);
}
- // Cleanup long-path test directory
{
std::error_code Ec;
- DeleteDirectories(MakeSafeAbsolutePath(RootPath / "longpathtest"), Ec);
+ DeleteDirectories(MakeSafeAbsolutePath(RootPath / "zerobyte_snapshot_test"), Ec);
}
+
+ ZEN_INFO("+++++++");
+ }
+
+ // --- test chunk not found error ---
+ {
+ std::string OplogUri = CreateProjectAndOplog("test_notfound", "oplog_notfound");
+ HttpClient Http{OplogUri};
+
+ for (size_t I = 0; I < 65; I++)
+ {
+ zen::StringBuilder<128> PostUri;
+ PostUri << "/f77c781846caead318084604/info";
+ auto Response = Http.Get(PostUri);
+
+ REQUIRE(!Response.Error);
+ CHECK(Response.StatusCode == HttpResponseCode::NotFound);
+ }
+ }
+
+ // Cleanup long-path test directory
+ {
+ std::error_code Ec;
+ DeleteDirectories(MakeSafeAbsolutePath(RootPath / "longpathtest"), Ec);
}
}
@@ -656,86 +760,102 @@ TEST_CASE("project.remote")
}
};
- SUBCASE("File")
+ // --- Zen ---
+ // NOTE: Zen export must run before file-based exports from the same source
+ // oplog. A prior file export leaves server-side state that causes a
+ // subsequent zen-protocol export from the same oplog to abort.
{
+ INFO("Zen");
ScopedTemporaryDirectory TempDir;
{
- IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes, path = TempDir.Path().string()](CbObjectWriter& Writer) {
+ std::string ExportSourceUri = Servers.GetInstance(0).GetBaseUri();
+ std::string ExportTargetUri = Servers.GetInstance(1).GetBaseUri();
+ MakeProject(ExportTargetUri, "proj0_zen");
+ MakeOplog(ExportTargetUri, "proj0_zen", "oplog0_zen");
+
+ IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
Writer << "method"sv
<< "export"sv;
Writer << "params" << BeginObject;
{
Writer << "maxblocksize"sv << 3072u;
Writer << "maxchunkembedsize"sv << 1296u;
- Writer << "chunkfilesizelimit"sv << 5u * 1024u;
Writer << "maxchunksperblock"sv << 16u;
+ Writer << "chunkfilesizelimit"sv << 5u * 1024u;
Writer << "force"sv << false;
- Writer << "file"sv << BeginObject;
+ Writer << "zen"sv << BeginObject;
{
- Writer << "path"sv << path;
- Writer << "name"sv
- << "proj0_oplog0"sv;
+ Writer << "url"sv << ExportTargetUri.substr(7);
+ Writer << "project"
+ << "proj0_zen";
+ Writer << "oplog"
+ << "oplog0_zen";
}
- Writer << EndObject; // "file"
+ Writer << EndObject; // "zen"
}
Writer << EndObject; // "params"
});
- HttpClient Http{Servers.GetInstance(0).GetBaseUri()};
-
+ HttpClient Http{Servers.GetInstance(0).GetBaseUri()};
HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0", "oplog0"), Payload);
HttpWaitForCompletion(Servers.GetInstance(0), Response);
}
+ ValidateAttachments(1, "proj0_zen", "oplog0_zen");
+ ValidateOplog(1, "proj0_zen", "oplog0_zen");
+
{
- MakeProject(Servers.GetInstance(1).GetBaseUri(), "proj0_copy");
- MakeOplog(Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy");
+ std::string ImportSourceUri = Servers.GetInstance(1).GetBaseUri();
+ std::string ImportTargetUri = Servers.GetInstance(2).GetBaseUri();
+ MakeProject(ImportTargetUri, "proj1");
+ MakeOplog(ImportTargetUri, "proj1", "oplog1");
- IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes, path = TempDir.Path().string()](CbObjectWriter& Writer) {
+ IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
Writer << "method"sv
<< "import"sv;
Writer << "params" << BeginObject;
{
Writer << "force"sv << false;
- Writer << "file"sv << BeginObject;
+ Writer << "zen"sv << BeginObject;
{
- Writer << "path"sv << path;
- Writer << "name"sv
- << "proj0_oplog0"sv;
+ Writer << "url"sv << ImportSourceUri.substr(7);
+ Writer << "project"
+ << "proj0_zen";
+ Writer << "oplog"
+ << "oplog0_zen";
}
- Writer << EndObject; // "file"
+ Writer << EndObject; // "zen"
}
Writer << EndObject; // "params"
});
- HttpClient Http{Servers.GetInstance(1).GetBaseUri()};
-
- HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0_copy", "oplog0_copy"), Payload);
- HttpWaitForCompletion(Servers.GetInstance(1), Response);
+ HttpClient Http{Servers.GetInstance(2).GetBaseUri()};
+ HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj1", "oplog1"), Payload);
+ HttpWaitForCompletion(Servers.GetInstance(2), Response);
}
- ValidateAttachments(1, "proj0_copy", "oplog0_copy");
- ValidateOplog(1, "proj0_copy", "oplog0_copy");
+ ValidateAttachments(2, "proj1", "oplog1");
+ ValidateOplog(2, "proj1", "oplog1");
}
- SUBCASE("File disable blocks")
+ // --- File ---
{
+ INFO("File");
ScopedTemporaryDirectory TempDir;
{
- IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
+ IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes, path = TempDir.Path().string()](CbObjectWriter& Writer) {
Writer << "method"sv
<< "export"sv;
Writer << "params" << BeginObject;
{
Writer << "maxblocksize"sv << 3072u;
Writer << "maxchunkembedsize"sv << 1296u;
- Writer << "maxchunksperblock"sv << 16u;
Writer << "chunkfilesizelimit"sv << 5u * 1024u;
- Writer << "force"sv << false;
+ Writer << "maxchunksperblock"sv << 16u;
+ Writer << "force"sv << true;
Writer << "file"sv << BeginObject;
{
- Writer << "path"sv << TempDir.Path().string();
+ Writer << "path"sv << path;
Writer << "name"sv
<< "proj0_oplog0"sv;
- Writer << "disableblocks"sv << true;
}
Writer << EndObject; // "file"
}
@@ -748,9 +868,10 @@ TEST_CASE("project.remote")
HttpWaitForCompletion(Servers.GetInstance(0), Response);
}
{
- MakeProject(Servers.GetInstance(1).GetBaseUri(), "proj0_copy");
- MakeOplog(Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy");
- IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
+ MakeProject(Servers.GetInstance(1).GetBaseUri(), "proj0_file");
+ MakeOplog(Servers.GetInstance(1).GetBaseUri(), "proj0_file", "oplog0_file");
+
+ IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes, path = TempDir.Path().string()](CbObjectWriter& Writer) {
Writer << "method"sv
<< "import"sv;
Writer << "params" << BeginObject;
@@ -758,7 +879,7 @@ TEST_CASE("project.remote")
Writer << "force"sv << false;
Writer << "file"sv << BeginObject;
{
- Writer << "path"sv << TempDir.Path().string();
+ Writer << "path"sv << path;
Writer << "name"sv
<< "proj0_oplog0"sv;
}
@@ -769,15 +890,16 @@ TEST_CASE("project.remote")
HttpClient Http{Servers.GetInstance(1).GetBaseUri()};
- HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0_copy", "oplog0_copy"), Payload);
+ HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0_file", "oplog0_file"), Payload);
HttpWaitForCompletion(Servers.GetInstance(1), Response);
}
- ValidateAttachments(1, "proj0_copy", "oplog0_copy");
- ValidateOplog(1, "proj0_copy", "oplog0_copy");
+ ValidateAttachments(1, "proj0_file", "oplog0_file");
+ ValidateOplog(1, "proj0_file", "oplog0_file");
}
- SUBCASE("File force temp blocks")
+ // --- File disable blocks ---
{
+ INFO("File disable blocks");
ScopedTemporaryDirectory TempDir;
{
IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
@@ -789,26 +911,27 @@ TEST_CASE("project.remote")
Writer << "maxchunkembedsize"sv << 1296u;
Writer << "maxchunksperblock"sv << 16u;
Writer << "chunkfilesizelimit"sv << 5u * 1024u;
- Writer << "force"sv << false;
+ Writer << "force"sv << true;
Writer << "file"sv << BeginObject;
{
Writer << "path"sv << TempDir.Path().string();
Writer << "name"sv
<< "proj0_oplog0"sv;
- Writer << "enabletempblocks"sv << true;
+ Writer << "disableblocks"sv << true;
}
Writer << EndObject; // "file"
}
Writer << EndObject; // "params"
});
- HttpClient Http{Servers.GetInstance(0).GetBaseUri()};
+ HttpClient Http{Servers.GetInstance(0).GetBaseUri()};
+
HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0", "oplog0"), Payload);
HttpWaitForCompletion(Servers.GetInstance(0), Response);
}
{
- MakeProject(Servers.GetInstance(1).GetBaseUri(), "proj0_copy");
- MakeOplog(Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy");
+ MakeProject(Servers.GetInstance(1).GetBaseUri(), "proj0_noblock");
+ MakeOplog(Servers.GetInstance(1).GetBaseUri(), "proj0_noblock", "oplog0_noblock");
IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
Writer << "method"sv
<< "import"sv;
@@ -826,23 +949,20 @@ TEST_CASE("project.remote")
Writer << EndObject; // "params"
});
- HttpClient Http{Servers.GetInstance(1).GetBaseUri()};
- HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0_copy", "oplog0_copy"), Payload);
+ HttpClient Http{Servers.GetInstance(1).GetBaseUri()};
+
+ HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0_noblock", "oplog0_noblock"), Payload);
HttpWaitForCompletion(Servers.GetInstance(1), Response);
}
- ValidateAttachments(1, "proj0_copy", "oplog0_copy");
- ValidateOplog(1, "proj0_copy", "oplog0_copy");
+ ValidateAttachments(1, "proj0_noblock", "oplog0_noblock");
+ ValidateOplog(1, "proj0_noblock", "oplog0_noblock");
}
- SUBCASE("Zen")
+ // --- File force temp blocks ---
{
+ INFO("File force temp blocks");
ScopedTemporaryDirectory TempDir;
{
- std::string ExportSourceUri = Servers.GetInstance(0).GetBaseUri();
- std::string ExportTargetUri = Servers.GetInstance(1).GetBaseUri();
- MakeProject(ExportTargetUri, "proj0_copy");
- MakeOplog(ExportTargetUri, "proj0_copy", "oplog0_copy");
-
IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
Writer << "method"sv
<< "export"sv;
@@ -852,14 +972,13 @@ TEST_CASE("project.remote")
Writer << "maxchunkembedsize"sv << 1296u;
Writer << "maxchunksperblock"sv << 16u;
Writer << "chunkfilesizelimit"sv << 5u * 1024u;
- Writer << "force"sv << false;
- Writer << "zen"sv << BeginObject;
+ Writer << "force"sv << true;
+ Writer << "file"sv << BeginObject;
{
- Writer << "url"sv << ExportTargetUri.substr(7);
- Writer << "project"
- << "proj0_copy";
- Writer << "oplog"
- << "oplog0_copy";
+ Writer << "path"sv << TempDir.Path().string();
+ Writer << "name"sv
+ << "proj0_oplog0"sv;
+ Writer << "enabletempblocks"sv << true;
}
Writer << EndObject; // "file"
}
@@ -870,40 +989,32 @@ TEST_CASE("project.remote")
HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0", "oplog0"), Payload);
HttpWaitForCompletion(Servers.GetInstance(0), Response);
}
- ValidateAttachments(1, "proj0_copy", "oplog0_copy");
- ValidateOplog(1, "proj0_copy", "oplog0_copy");
-
{
- std::string ImportSourceUri = Servers.GetInstance(1).GetBaseUri();
- std::string ImportTargetUri = Servers.GetInstance(2).GetBaseUri();
- MakeProject(ImportTargetUri, "proj1");
- MakeOplog(ImportTargetUri, "proj1", "oplog1");
-
+ MakeProject(Servers.GetInstance(1).GetBaseUri(), "proj0_tmpblock");
+ MakeOplog(Servers.GetInstance(1).GetBaseUri(), "proj0_tmpblock", "oplog0_tmpblock");
IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
Writer << "method"sv
<< "import"sv;
Writer << "params" << BeginObject;
{
Writer << "force"sv << false;
- Writer << "zen"sv << BeginObject;
+ Writer << "file"sv << BeginObject;
{
- Writer << "url"sv << ImportSourceUri.substr(7);
- Writer << "project"
- << "proj0_copy";
- Writer << "oplog"
- << "oplog0_copy";
+ Writer << "path"sv << TempDir.Path().string();
+ Writer << "name"sv
+ << "proj0_oplog0"sv;
}
Writer << EndObject; // "file"
}
Writer << EndObject; // "params"
});
- HttpClient Http{Servers.GetInstance(2).GetBaseUri()};
- HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj1", "oplog1"), Payload);
- HttpWaitForCompletion(Servers.GetInstance(2), Response);
+ HttpClient Http{Servers.GetInstance(1).GetBaseUri()};
+ HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", "proj0_tmpblock", "oplog0_tmpblock"), Payload);
+ HttpWaitForCompletion(Servers.GetInstance(1), Response);
}
- ValidateAttachments(2, "proj1", "oplog1");
- ValidateOplog(2, "proj1", "oplog1");
+ ValidateAttachments(1, "proj0_tmpblock", "oplog0_tmpblock");
+ ValidateOplog(1, "proj0_tmpblock", "oplog0_tmpblock");
}
}
@@ -1154,6 +1265,368 @@ TEST_CASE("project.rpcappendop")
}
}
+TEST_CASE("project.file.data.transitions")
+{
+ using namespace utils;
+
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetDataDir(TestDir);
+ const uint16_t PortNumber = Instance.SpawnServerAndWaitUntilReady();
+
+ zen::StringBuilder<64> ServerBaseUri;
+ ServerBaseUri << fmt::format("http://localhost:{}", PortNumber);
+
+ // Set up a root directory with a test file on disk for path-referenced serving
+ std::filesystem::path RootDir = TestDir / "root";
+ std::filesystem::path TestFilePath = RootDir / "content" / "testfile.bin";
+ std::filesystem::path RelServerPath = std::filesystem::path("content") / "testfile.bin";
+ CreateDirectories(TestFilePath.parent_path());
+ IoBuffer FileBlob = CreateRandomBlob(4096);
+ WriteFile(TestFilePath, FileBlob);
+
+ // Create a compressed blob to use as a CAS-referenced attachment (content differs from FileBlob)
+ CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(CreateRandomBlob(2048)));
+
+ // Fixed chunk IDs for the file entry across sub-tests
+ const std::string_view FileChunkIdStr{
+ "aa000000"
+ "bb000000"
+ "cc000001"};
+ Oid FileOid = Oid::FromHexString(FileChunkIdStr);
+
+ HttpClient Http{ServerBaseUri};
+
+ auto MakeProject = [&](std::string_view ProjectName) {
+ CbObjectWriter Project;
+ Project.AddString("id"sv, ProjectName);
+ Project.AddString("root"sv, PathToUtf8(RootDir.c_str()));
+ Project.AddString("engine"sv, ""sv);
+ Project.AddString("project"sv, ""sv);
+ Project.AddString("projectfile"sv, ""sv);
+ HttpClient::Response Response = Http.Post(fmt::format("/prj/{}", ProjectName), Project.Save());
+ REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("MakeProject"));
+ };
+
+ auto MakeOplog = [&](std::string_view ProjectName, std::string_view OplogName) {
+ HttpClient::Response Response =
+ Http.Post(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName), IoBuffer{}, ZenContentType::kCbObject);
+ REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("MakeOplog"));
+ };
+
+ auto PostOplogEntry = [&](std::string_view ProjectName, std::string_view OplogName, const CbPackage& OpPackage) {
+ zen::BinaryWriter MemOut;
+ legacy::SaveCbPackage(OpPackage, MemOut);
+ IoBuffer Body{IoBuffer::Wrap, MemOut.GetData(), MemOut.GetSize()};
+ Body.SetContentType(HttpContentType::kCbPackage);
+ HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/new", ProjectName, OplogName), Body);
+ REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("PostOplogEntry"));
+ };
+
+ auto GetChunk = [&](std::string_view ProjectName) -> HttpClient::Response {
+ return Http.Get(fmt::format("/prj/{}/oplog/oplog/{}", ProjectName, FileChunkIdStr));
+ };
+
+ // Extract the raw decompressed bytes from a chunk response, handling both compressed and uncompressed payloads
+ auto GetDecompressedPayload = [](const HttpClient::Response& Response) -> IoBuffer {
+ if (Response.ResponsePayload.GetContentType() == ZenContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Response.ResponsePayload), RawHash, RawSize);
+ REQUIRE(Compressed);
+ return Compressed.Decompress().AsIoBuffer();
+ }
+ return Response.ResponsePayload;
+ };
+
+ auto TriggerGcAndWait = [&]() {
+ HttpClient::Response TriggerResponse = Http.Post("/admin/gc?smallobjects=true"sv, IoBuffer{});
+ REQUIRE_MESSAGE(TriggerResponse.IsSuccess(), TriggerResponse.ErrorMessage("TriggerGc"));
+
+ for (int Attempt = 0; Attempt < 100; ++Attempt)
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ HttpClient::Response StatusResponse = Http.Get("/admin/gc"sv);
+ REQUIRE_MESSAGE(StatusResponse.IsSuccess(), StatusResponse.ErrorMessage("GcStatus"));
+ CbObject StatusObj = StatusResponse.AsObject();
+ if (StatusObj["Status"sv].AsString() == "Idle"sv)
+ {
+ return;
+ }
+ }
+ FAIL("GC did not complete within timeout");
+ };
+
+ auto BuildPathReferencedFileOp = [&](const Oid& KeyId) -> CbPackage {
+ CbPackage Package;
+ CbObjectWriter Object;
+ Object << "key"sv << OidAsString(KeyId);
+ Object.BeginArray("files"sv);
+ Object.BeginObject();
+ Object << "id"sv << FileOid;
+ Object << "serverpath"sv << RelServerPath.string();
+ Object << "clientpath"sv
+ << "/{engine}/testfile.bin"sv;
+ Object.EndObject();
+ Object.EndArray();
+ Package.SetObject(Object.Save());
+ return Package;
+ };
+
+ auto BuildHashReferencedFileOp = [&](const Oid& KeyId, const CompressedBuffer& Blob) -> CbPackage {
+ CbPackage Package;
+ CbObjectWriter Object;
+ Object << "key"sv << OidAsString(KeyId);
+ CbAttachment Attach(Blob, Blob.DecodeRawHash());
+ Object.BeginArray("files"sv);
+ Object.BeginObject();
+ Object << "id"sv << FileOid;
+ Object << "data"sv << Attach;
+ Object << "clientpath"sv
+ << "/{engine}/testfile.bin"sv;
+ Object.EndObject();
+ Object.EndArray();
+ Package.AddAttachment(Attach);
+ Package.SetObject(Object.Save());
+ return Package;
+ };
+
+ // --- path-referenced file is retrievable ---
+ {
+ MakeProject("proj_path"sv);
+ MakeOplog("proj_path"sv, "oplog"sv);
+
+ CbPackage Op = BuildPathReferencedFileOp(Oid::NewOid());
+ PostOplogEntry("proj_path"sv, "oplog"sv, Op);
+
+ HttpClient::Response Response = GetChunk("proj_path"sv);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK_EQ(Payload.GetSize(), FileBlob.GetSize());
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+
+ // --- hash-referenced file is retrievable ---
+ {
+ MakeProject("proj_hash"sv);
+ MakeOplog("proj_hash"sv, "oplog"sv);
+
+ CbPackage Op = BuildHashReferencedFileOp(Oid::NewOid(), CompressedBlob);
+ PostOplogEntry("proj_hash"sv, "oplog"sv, Op);
+
+ HttpClient::Response Response = GetChunk("proj_hash"sv);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ IoBuffer ExpectedDecompressed = CompressedBlob.Decompress().AsIoBuffer();
+ CHECK_EQ(Payload.GetSize(), ExpectedDecompressed.GetSize());
+ CHECK(Payload.GetView().EqualBytes(ExpectedDecompressed.GetView()));
+ }
+ }
+
+ struct TransitionVariant
+ {
+ std::string_view Suffix;
+ bool SameOpKey;
+ bool RunGc;
+ };
+
+ static constexpr TransitionVariant Variants[] = {
+ {"_nk", false, false},
+ {"_sk", true, false},
+ {"_nk_gc", false, true},
+ {"_sk_gc", true, true},
+ };
+
+ // --- hash-referenced to path-referenced transition with different content ---
+ for (const TransitionVariant& V : Variants)
+ {
+ std::string ProjName = fmt::format("proj_h2pd{}", V.Suffix);
+ MakeProject(ProjName);
+ MakeOplog(ProjName, "oplog"sv);
+
+ Oid FirstOpKey = Oid::NewOid();
+ Oid SecondOpKey = V.SameOpKey ? FirstOpKey : Oid::NewOid();
+
+ {
+ CbPackage Op = BuildHashReferencedFileOp(FirstOpKey, CompressedBlob);
+ PostOplogEntry(ProjName, "oplog"sv, Op);
+
+ HttpClient::Response Response = GetChunk(ProjName);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ IoBuffer ExpectedDecompressed = CompressedBlob.Decompress().AsIoBuffer();
+ CHECK(Payload.GetView().EqualBytes(ExpectedDecompressed.GetView()));
+ }
+ }
+
+ {
+ CbPackage Op = BuildPathReferencedFileOp(SecondOpKey);
+ PostOplogEntry(ProjName, "oplog"sv, Op);
+ }
+
+ if (V.RunGc)
+ {
+ TriggerGcAndWait();
+ }
+
+ HttpClient::Response Response = GetChunk(ProjName);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK_EQ(Payload.GetSize(), FileBlob.GetSize());
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+
+ // --- hash-referenced to path-referenced transition with identical content ---
+ {
+ CompressedBuffer MatchingBlob = CompressedBuffer::Compress(SharedBuffer::Clone(FileBlob.GetView()));
+
+ for (const TransitionVariant& V : Variants)
+ {
+ std::string ProjName = fmt::format("proj_h2ps{}", V.Suffix);
+ MakeProject(ProjName);
+ MakeOplog(ProjName, "oplog"sv);
+
+ Oid FirstOpKey = Oid::NewOid();
+ Oid SecondOpKey = V.SameOpKey ? FirstOpKey : Oid::NewOid();
+
+ {
+ CbPackage Op = BuildHashReferencedFileOp(FirstOpKey, MatchingBlob);
+ PostOplogEntry(ProjName, "oplog"sv, Op);
+
+ HttpClient::Response Response = GetChunk(ProjName);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+
+ {
+ CbPackage Op = BuildPathReferencedFileOp(SecondOpKey);
+ PostOplogEntry(ProjName, "oplog"sv, Op);
+ }
+
+ if (V.RunGc)
+ {
+ TriggerGcAndWait();
+ }
+
+ HttpClient::Response Response = GetChunk(ProjName);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK_EQ(Payload.GetSize(), FileBlob.GetSize());
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+ }
+
+ // --- path-referenced to hash-referenced transition with different content ---
+ for (const TransitionVariant& V : Variants)
+ {
+ std::string ProjName = fmt::format("proj_p2hd{}", V.Suffix);
+ MakeProject(ProjName);
+ MakeOplog(ProjName, "oplog"sv);
+
+ Oid FirstOpKey = Oid::NewOid();
+ Oid SecondOpKey = V.SameOpKey ? FirstOpKey : Oid::NewOid();
+
+ {
+ CbPackage Op = BuildPathReferencedFileOp(FirstOpKey);
+ PostOplogEntry(ProjName, "oplog"sv, Op);
+
+ HttpClient::Response Response = GetChunk(ProjName);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+
+ {
+ CbPackage Op = BuildHashReferencedFileOp(SecondOpKey, CompressedBlob);
+ PostOplogEntry(ProjName, "oplog"sv, Op);
+ }
+
+ if (V.RunGc)
+ {
+ TriggerGcAndWait();
+ }
+
+ HttpClient::Response Response = GetChunk(ProjName);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ IoBuffer ExpectedDecompressed = CompressedBlob.Decompress().AsIoBuffer();
+ CHECK_EQ(Payload.GetSize(), ExpectedDecompressed.GetSize());
+ CHECK(Payload.GetView().EqualBytes(ExpectedDecompressed.GetView()));
+ }
+ }
+
+ // --- path-referenced to hash-referenced transition with identical content ---
+ {
+ CompressedBuffer MatchingBlob = CompressedBuffer::Compress(SharedBuffer::Clone(FileBlob.GetView()));
+
+ for (const TransitionVariant& V : Variants)
+ {
+ std::string ProjName = fmt::format("proj_p2hs{}", V.Suffix);
+ MakeProject(ProjName);
+ MakeOplog(ProjName, "oplog"sv);
+
+ Oid FirstOpKey = Oid::NewOid();
+ Oid SecondOpKey = V.SameOpKey ? FirstOpKey : Oid::NewOid();
+
+ {
+ CbPackage Op = BuildPathReferencedFileOp(FirstOpKey);
+ PostOplogEntry(ProjName, "oplog"sv, Op);
+
+ HttpClient::Response Response = GetChunk(ProjName);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk first op"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+
+ {
+ CbPackage Op = BuildHashReferencedFileOp(SecondOpKey, MatchingBlob);
+ PostOplogEntry(ProjName, "oplog"sv, Op);
+ }
+
+ if (V.RunGc)
+ {
+ TriggerGcAndWait();
+ }
+
+ HttpClient::Response Response = GetChunk(ProjName);
+ CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("GetChunk after transition"));
+ if (Response.IsSuccess())
+ {
+ IoBuffer Payload = GetDecompressedPayload(Response);
+ CHECK_EQ(Payload.GetSize(), FileBlob.GetSize());
+ CHECK(Payload.GetView().EqualBytes(FileBlob.GetView()));
+ }
+ }
+ }
+}
+
TEST_SUITE_END();
} // namespace zen::tests
diff --git a/src/zenserver-test/xmake.lua b/src/zenserver-test/xmake.lua
index 7b208bbc7..c240712ea 100644
--- a/src/zenserver-test/xmake.lua
+++ b/src/zenserver-test/xmake.lua
@@ -9,7 +9,7 @@ target("zenserver-test")
add_deps("zencore", "zenremotestore", "zenhttp", "zencompute", "zenstore")
add_deps("zenserver", {inherit=false})
add_deps("zentest-appstub", {inherit=false})
- add_packages("http_parser")
+ add_packages("llhttp")
if has_config("zennomad") then
add_deps("zennomad")
diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp
index cf7ffe4e4..d713f693f 100644
--- a/src/zenserver-test/zenserver-test.cpp
+++ b/src/zenserver-test/zenserver-test.cpp
@@ -199,7 +199,7 @@ TEST_CASE("default.single")
HttpClient Http{fmt::format("http://localhost:{}", PortNumber)};
- for (int i = 0; i < 100; ++i)
+ for (int i = 0; i < 20; ++i)
{
auto res = Http.Get("/test/hello"sv);
++RequestCount;
@@ -238,7 +238,6 @@ TEST_CASE("default.loopback")
ZEN_INFO("Running loopback server test...");
- SUBCASE("ipv4 endpoint connectivity")
{
HttpClient Http{fmt::format("http://127.0.0.1:{}", PortNumber)};
@@ -247,7 +246,6 @@ TEST_CASE("default.loopback")
CHECK(res);
}
- SUBCASE("ipv6 endpoint connectivity")
{
HttpClient Http{fmt::format("http://[::1]:{}", PortNumber)};
@@ -287,7 +285,7 @@ TEST_CASE("multi.basic")
HttpClient Http{fmt::format("http://localhost:{}", PortNumber)};
- for (int i = 0; i < 100; ++i)
+ for (int i = 0; i < 20; ++i)
{
auto res = Http.Get("/test/hello"sv);
++RequestCount;
@@ -401,13 +399,11 @@ TEST_CASE("http.unixsocket")
Settings.UnixSocketPath = SocketPath;
HttpClient Http{fmt::format("http://localhost:{}", PortNumber), Settings, {}};
- SUBCASE("GET over unix socket")
{
HttpClient::Response Res = Http.Get("/testing/hello");
CHECK(Res.IsSuccess());
}
- SUBCASE("POST echo over unix socket")
{
IoBuffer Body{IoBuffer::Wrap, "unix-test", 9};
HttpClient::Response Res = Http.Post("/testing/echo", Body);
@@ -431,13 +427,11 @@ TEST_CASE("http.nonetwork")
Settings.UnixSocketPath = SocketPath;
HttpClient Http{fmt::format("http://localhost:{}", PortNumber), Settings, {}};
- SUBCASE("GET over unix socket succeeds")
{
HttpClient::Response Res = Http.Get("/testing/hello");
CHECK(Res.IsSuccess());
}
- SUBCASE("TCP connection is refused")
{
asio::io_context IoContext;
asio::ip::tcp::socket Socket(IoContext);
diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp
index 1673cea6c..b110f7538 100644
--- a/src/zenserver/compute/computeserver.cpp
+++ b/src/zenserver/compute/computeserver.cpp
@@ -22,6 +22,8 @@
# if ZEN_WITH_HORDE
# include <zenhorde/hordeconfig.h>
# include <zenhorde/hordeprovisioner.h>
+# include <zenhttp/httpclientauth.h>
+# include <zenutil/authutils.h>
# endif
# if ZEN_WITH_NOMAD
# include <zennomad/nomadconfig.h>
@@ -67,6 +69,20 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options)
Options.add_option("compute",
"",
+ "coordinator-session",
+ "Session ID of the orchestrator (for stale-instance rejection)",
+ cxxopts::value<std::string>(m_ServerOptions.CoordinatorSession)->default_value(""),
+ "");
+
+ Options.add_option("compute",
+ "",
+ "announce-url",
+ "Override URL announced to the coordinator (e.g. relay-visible endpoint)",
+ cxxopts::value<std::string>(m_ServerOptions.AnnounceUrl)->default_value(""),
+ "");
+
+ Options.add_option("compute",
+ "",
"idms",
"Enable IDMS cloud detection; optionally specify a custom probe endpoint",
cxxopts::value<std::string>(m_ServerOptions.IdmsEndpoint)->default_value("")->implicit_value("auto"),
@@ -79,6 +95,20 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options)
cxxopts::value<bool>(m_ServerOptions.EnableWorkerWebSocket)->default_value("false"),
"");
+ Options.add_option("compute",
+ "",
+ "provision-clean",
+ "Pass --clean to provisioned worker instances so they wipe state on startup",
+ cxxopts::value<bool>(m_ServerOptions.ProvisionClean)->default_value("false"),
+ "");
+
+ Options.add_option("compute",
+ "",
+ "provision-tracehost",
+ "Pass --tracehost to provisioned worker instances for remote trace collection",
+ cxxopts::value<std::string>(m_ServerOptions.ProvisionTraceHost)->default_value(""),
+ "");
+
# if ZEN_WITH_HORDE
// Horde provisioning options
Options.add_option("horde",
@@ -139,6 +169,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options)
Options.add_option("horde",
"",
+ "horde-drain-grace-period",
+ "Grace period in seconds for draining agents before force-kill",
+ cxxopts::value<int>(m_ServerOptions.HordeConfig.DrainGracePeriodSeconds)->default_value("300"),
+ "");
+
+ Options.add_option("horde",
+ "",
"horde-host",
"Host address for Horde agents to connect back to",
cxxopts::value<std::string>(m_ServerOptions.HordeConfig.HostAddress)->default_value(""),
@@ -164,6 +201,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options)
"Port number for Zen service communication",
cxxopts::value<uint16_t>(m_ServerOptions.HordeConfig.ZenServicePort)->default_value("8558"),
"");
+
+ Options.add_option("horde",
+ "",
+ "horde-oidctoken-exe-path",
+ "Path to OidcToken executable for automatic Horde authentication",
+ cxxopts::value<std::string>(m_HordeOidcTokenExePath)->default_value(""),
+ "");
# endif
# if ZEN_WITH_NOMAD
@@ -313,6 +357,30 @@ ZenComputeServerConfigurator::ValidateOptions()
# if ZEN_WITH_HORDE
horde::FromString(m_ServerOptions.HordeConfig.Mode, m_HordeModeStr);
horde::FromString(m_ServerOptions.HordeConfig.EncryptionMode, m_HordeEncryptionStr);
+
+ // Set up OidcToken-based authentication if no static token was provided
+ if (m_ServerOptions.HordeConfig.AuthToken.empty() && !m_ServerOptions.HordeConfig.ServerUrl.empty())
+ {
+ std::filesystem::path OidcExePath = FindOidcTokenExePath(m_HordeOidcTokenExePath);
+ if (!OidcExePath.empty())
+ {
+ ZEN_INFO("using OidcToken executable for Horde authentication: {}", OidcExePath);
+ auto Provider = httpclientauth::CreateFromOidcTokenExecutable(OidcExePath,
+ m_ServerOptions.HordeConfig.ServerUrl,
+ /*Quiet=*/true,
+ /*Unattended=*/false,
+ /*Hidden=*/true,
+ /*IsHordeUrl=*/true);
+ if (Provider)
+ {
+ m_ServerOptions.HordeConfig.AccessTokenProvider = std::move(*Provider);
+ }
+ else
+ {
+ ZEN_WARN("OidcToken authentication failed; Horde requests will be unauthenticated");
+ }
+ }
+ }
# endif
# if ZEN_WITH_NOMAD
@@ -347,6 +415,8 @@ ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ
}
m_CoordinatorEndpoint = ServerConfig.CoordinatorEndpoint;
+ m_CoordinatorSession = ServerConfig.CoordinatorSession;
+ m_AnnounceUrl = ServerConfig.AnnounceUrl;
m_InstanceId = ServerConfig.InstanceId;
m_EnableWorkerWebSocket = ServerConfig.EnableWorkerWebSocket;
@@ -379,13 +449,20 @@ ZenComputeServer::Cleanup()
m_AnnounceTimer.cancel();
# if ZEN_WITH_HORDE
- // Shut down Horde provisioner first — this signals all agent threads
+ // Disconnect the provisioner state provider before destroying the
+ // provisioner so the orchestrator HTTP layer cannot call into it.
+ if (m_OrchestratorService)
+ {
+ m_OrchestratorService->SetProvisionerStateProvider(nullptr);
+ }
+
+ // Shut down Horde provisioner - this signals all agent threads
// to exit and joins them before we tear down HTTP services.
m_HordeProvisioner.reset();
# endif
# if ZEN_WITH_NOMAD
- // Shut down Nomad provisioner — stops the management thread and
+ // Shut down Nomad provisioner - stops the management thread and
// sends stop requests for all tracked jobs.
m_NomadProvisioner.reset();
# endif
@@ -419,12 +496,12 @@ ZenComputeServer::Cleanup()
m_IoRunner.join();
}
- ShutdownServices();
-
if (m_Http)
{
m_Http->Close();
}
+
+ ShutdownServices();
}
catch (const std::exception& Ex)
{
@@ -444,11 +521,12 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig)
ZEN_TRACE_CPU("ZenComputeServer::InitializeServices");
ZEN_INFO("initializing compute services");
- CidStoreConfiguration Config;
- Config.RootDirectory = m_DataRoot / "cas";
+ m_ActionStore = std::make_unique<MemoryCidStore>();
- m_CidStore = std::make_unique<CidStore>(m_GcManager);
- m_CidStore->Initialize(Config);
+ CidStoreConfiguration WorkerStoreConfig;
+ WorkerStoreConfig.RootDirectory = m_DataRoot / "cas";
+ m_WorkerStore = std::make_unique<CidStore>(m_GcManager);
+ m_WorkerStore->Initialize(WorkerStoreConfig);
if (!ServerConfig.IdmsEndpoint.empty())
{
@@ -476,10 +554,12 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig)
std::make_unique<zen::compute::HttpOrchestratorService>(ServerConfig.DataDir / "orch", ServerConfig.EnableWorkerWebSocket);
ZEN_INFO("instantiating function service");
- m_ComputeService = std::make_unique<zen::compute::HttpComputeService>(*m_CidStore,
+ m_ComputeService = std::make_unique<zen::compute::HttpComputeService>(*m_ActionStore,
+ *m_WorkerStore,
m_StatsService,
ServerConfig.DataDir / "functions",
ServerConfig.MaxConcurrentActions);
+ m_ComputeService->SetShutdownCallback([this] { RequestExit(0); });
m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatsService, m_StatusService);
@@ -504,7 +584,11 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig)
OrchestratorEndpoint << '/';
}
- m_NomadProvisioner = std::make_unique<nomad::NomadProvisioner>(NomadCfg, OrchestratorEndpoint);
+ m_NomadProvisioner = std::make_unique<nomad::NomadProvisioner>(NomadCfg,
+ OrchestratorEndpoint,
+ m_OrchestratorService->GetSessionId().ToString(),
+ ServerConfig.ProvisionClean,
+ ServerConfig.ProvisionTraceHost);
}
}
# endif
@@ -535,7 +619,14 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig)
: std::filesystem::path(HordeConfig.BinariesPath);
std::filesystem::path WorkingDir = ServerConfig.DataDir / "horde";
- m_HordeProvisioner = std::make_unique<horde::HordeProvisioner>(HordeConfig, BinariesPath, WorkingDir, OrchestratorEndpoint);
+ m_HordeProvisioner = std::make_unique<horde::HordeProvisioner>(HordeConfig,
+ BinariesPath,
+ WorkingDir,
+ OrchestratorEndpoint,
+ m_OrchestratorService->GetSessionId().ToString(),
+ ServerConfig.ProvisionClean,
+ ServerConfig.ProvisionTraceHost);
+ m_OrchestratorService->SetProvisionerStateProvider(m_HordeProvisioner.get());
}
}
# endif
@@ -563,6 +654,10 @@ ZenComputeServer::GetInstanceId() const
std::string
ZenComputeServer::GetAnnounceUrl() const
{
+ if (!m_AnnounceUrl.empty())
+ {
+ return m_AnnounceUrl;
+ }
return m_Http->GetServiceUri(nullptr);
}
@@ -633,6 +728,11 @@ ZenComputeServer::BuildAnnounceBody()
<< "nomad";
}
+ if (!m_CoordinatorSession.empty())
+ {
+ AnnounceBody << "coordinator_session" << m_CoordinatorSession;
+ }
+
ResolveCloudMetadata();
if (m_CloudMetadata)
{
@@ -779,8 +879,10 @@ ZenComputeServer::ProvisionerMaintenanceTick()
# if ZEN_WITH_HORDE
if (m_HordeProvisioner)
{
- m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX);
+ // Re-apply current target to spawn agent threads for any that have
+ // exited since the last tick, without overwriting a user-set target.
auto Stats = m_HordeProvisioner->GetStats();
+ m_HordeProvisioner->SetTargetCoreCount(Stats.TargetCoreCount);
ZEN_DEBUG("Horde maintenance: target={}, estimated={}, active={}",
Stats.TargetCoreCount,
Stats.EstimatedCoreCount,
@@ -882,12 +984,14 @@ ZenComputeServer::Run()
OnReady();
+ StartSelfSession("zencompute");
+
PostAnnounce();
EnqueueAnnounceTimer();
InitializeOrchestratorWebSocket();
# if ZEN_WITH_HORDE
- // Start Horde provisioning if configured — request maximum allowed cores.
+ // Start Horde provisioning if configured - request maximum allowed cores.
// SetTargetCoreCount clamps to HordeConfig::MaxCores internally.
if (m_HordeProvisioner)
{
@@ -899,7 +1003,7 @@ ZenComputeServer::Run()
# endif
# if ZEN_WITH_NOMAD
- // Start Nomad provisioning if configured — request maximum allowed cores.
+ // Start Nomad provisioning if configured - request maximum allowed cores.
// SetTargetCoreCount clamps to NomadConfig::MaxCores internally.
if (m_NomadProvisioner)
{
diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h
index 8f4edc0f0..aa9c1a5b3 100644
--- a/src/zenserver/compute/computeserver.h
+++ b/src/zenserver/compute/computeserver.h
@@ -10,6 +10,7 @@
# include <zencore/system.h>
# include <zenhttp/httpwsclient.h>
# include <zenstore/gc.h>
+# include <zenstore/memorycidstore.h>
# include "frontend/frontend.h"
namespace cxxopts {
@@ -41,7 +42,6 @@ class NomadProvisioner;
namespace zen {
-class CidStore;
class HttpApiService;
struct ZenComputeServerConfig : public ZenServerConfig
@@ -49,9 +49,13 @@ struct ZenComputeServerConfig : public ZenServerConfig
std::string UpstreamNotificationEndpoint;
std::string InstanceId; // For use in notifications
std::string CoordinatorEndpoint;
+ std::string CoordinatorSession; ///< Session ID for stale-instance rejection
+ std::string AnnounceUrl; ///< Override for self-announced URL (e.g. relay-visible endpoint)
std::string IdmsEndpoint;
int32_t MaxConcurrentActions = 0; // 0 = auto (LogicalProcessorCount * 2)
- bool EnableWorkerWebSocket = false; // Use WebSocket for worker↔orchestrator link
+ bool EnableWorkerWebSocket = false; // Use WebSocket for worker<->orchestrator link
+ bool ProvisionClean = false; // Pass --clean to provisioned workers
+ std::string ProvisionTraceHost; // Pass --tracehost to provisioned workers
# if ZEN_WITH_HORDE
horde::HordeConfig HordeConfig;
@@ -84,6 +88,7 @@ private:
# if ZEN_WITH_HORDE
std::string m_HordeModeStr = "direct";
std::string m_HordeEncryptionStr = "none";
+ std::string m_HordeOidcTokenExePath;
# endif
# if ZEN_WITH_NOMAD
@@ -131,7 +136,8 @@ public:
private:
GcManager m_GcManager;
GcScheduler m_GcScheduler{m_GcManager};
- std::unique_ptr<CidStore> m_CidStore;
+ std::unique_ptr<MemoryCidStore> m_ActionStore;
+ std::unique_ptr<CidStore> m_WorkerStore;
std::unique_ptr<HttpApiService> m_ApiService;
std::unique_ptr<zen::compute::HttpComputeService> m_ComputeService;
std::unique_ptr<zen::compute::HttpOrchestratorService> m_OrchestratorService;
@@ -146,6 +152,8 @@ private:
# endif
SystemMetricsTracker m_MetricsTracker;
std::string m_CoordinatorEndpoint;
+ std::string m_CoordinatorSession;
+ std::string m_AnnounceUrl;
std::string m_InstanceId;
asio::steady_timer m_AnnounceTimer{m_IoContext};
@@ -163,7 +171,7 @@ private:
std::string GetInstanceId() const;
CbObject BuildAnnounceBody();
- // Worker→orchestrator WebSocket client
+ // Worker->orchestrator WebSocket client
struct OrchestratorWsHandler : public IWsClientHandler
{
ZenComputeServer& Server;
diff --git a/src/zenserver/config/config.cpp b/src/zenserver/config/config.cpp
index daad154bc..6449159fd 100644
--- a/src/zenserver/config/config.cpp
+++ b/src/zenserver/config/config.cpp
@@ -12,6 +12,7 @@
#include <zencore/compactbinaryutil.h>
#include <zencore/compactbinaryvalidation.h>
#include <zencore/except.h>
+#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
@@ -478,15 +479,27 @@ ZenServerCmdLineOptions::ApplyOptions(cxxopts::Options& options, ZenServerConfig
throw std::runtime_error(fmt::format("'--snapshot-dir' ('{}') must be a directory", ServerOptions.BaseSnapshotDir));
}
- ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir);
- ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir);
- ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir);
- ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile);
- ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir);
+ SystemRootDir = ExpandEnvironmentVariables(SystemRootDir);
+ ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir);
+
+ DataDir = ExpandEnvironmentVariables(DataDir);
+ ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir);
+
+ ContentDir = ExpandEnvironmentVariables(ContentDir);
+ ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir);
+
+ ConfigFile = ExpandEnvironmentVariables(ConfigFile);
+ ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile);
+
+ BaseSnapshotDir = ExpandEnvironmentVariables(BaseSnapshotDir);
+ ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir);
+
+ ExpandEnvironmentVariables(SecurityConfigPath);
ServerOptions.SecurityConfigPath = MakeSafeAbsolutePath(SecurityConfigPath);
if (!UnixSocketPath.empty())
{
+ UnixSocketPath = ExpandEnvironmentVariables(UnixSocketPath);
ServerOptions.HttpConfig.UnixSocketPath = MakeSafeAbsolutePath(UnixSocketPath);
}
diff --git a/src/zenserver/diag/logging.cpp b/src/zenserver/diag/logging.cpp
index f3d8dbfe3..e1a8fed7d 100644
--- a/src/zenserver/diag/logging.cpp
+++ b/src/zenserver/diag/logging.cpp
@@ -112,7 +112,7 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService)
const zen::Oid ServerSessionId = zen::GetSessionId();
logging::Registry::Instance().ApplyAll([&](auto Logger) {
- static constinit logging::LogPoint SessionIdPoint{{}, logging::Info, "server session id: {}"};
+ static constinit logging::LogPoint SessionIdPoint{0, 0, logging::Info, "server session id: {}"};
ZEN_MEMSCOPE(ELLMTag::Logging);
Logger->Log(SessionIdPoint, fmt::make_format_args(ServerSessionId));
});
diff --git a/src/zenserver/frontend/frontend.cpp b/src/zenserver/frontend/frontend.cpp
index 52ec5b8b3..812536074 100644
--- a/src/zenserver/frontend/frontend.cpp
+++ b/src/zenserver/frontend/frontend.cpp
@@ -160,7 +160,7 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request)
ContentType = ParseContentType(DotExt);
- // Extensions used only for static file serving — not in the global
+ // Extensions used only for static file serving - not in the global
// ParseContentType table because that table also drives URI extension
// stripping for content negotiation, and we don't want /api/foo.txt to
// have its extension removed.
diff --git a/src/zenserver/frontend/html/compute/compute.html b/src/zenserver/frontend/html/compute/compute.html
deleted file mode 100644
index c07bbb692..000000000
--- a/src/zenserver/frontend/html/compute/compute.html
+++ /dev/null
@@ -1,925 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-<head>
- <meta charset="UTF-8">
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
- <title>Zen Compute Dashboard</title>
- <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/chart.umd.min.js"></script>
- <link rel="stylesheet" type="text/css" href="../zen.css" />
- <script src="../util/sanitize.js"></script>
- <script src="../theme.js"></script>
- <script src="../banner.js" defer></script>
- <script src="../nav.js" defer></script>
- <style>
- .grid {
- grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
- }
-
- .chart-container {
- position: relative;
- height: 300px;
- margin-top: 20px;
- }
-
- .stats-row {
- display: flex;
- justify-content: space-between;
- margin-bottom: 12px;
- padding: 8px 0;
- border-bottom: 1px solid var(--theme_border_subtle);
- }
-
- .stats-row:last-child {
- border-bottom: none;
- margin-bottom: 0;
- }
-
- .stats-label {
- color: var(--theme_g1);
- font-size: 13px;
- }
-
- .stats-value {
- color: var(--theme_bright);
- font-weight: 600;
- font-size: 13px;
- }
-
- .rate-stats {
- display: grid;
- grid-template-columns: repeat(3, 1fr);
- gap: 16px;
- margin-top: 16px;
- }
-
- .rate-item {
- text-align: center;
- }
-
- .rate-value {
- font-size: 20px;
- font-weight: 600;
- color: var(--theme_p0);
- }
-
- .rate-label {
- font-size: 11px;
- color: var(--theme_g1);
- margin-top: 4px;
- text-transform: uppercase;
- }
-
- .worker-row {
- cursor: pointer;
- transition: background 0.15s;
- }
-
- .worker-row:hover {
- background: var(--theme_p4);
- }
-
- .worker-row.selected {
- background: var(--theme_p3);
- }
-
- .worker-detail {
- margin-top: 20px;
- border-top: 1px solid var(--theme_g2);
- padding-top: 16px;
- }
-
- .worker-detail-title {
- font-size: 15px;
- font-weight: 600;
- color: var(--theme_bright);
- margin-bottom: 12px;
- }
-
- .detail-section {
- margin-bottom: 16px;
- }
-
- .detail-section-label {
- font-size: 11px;
- font-weight: 600;
- color: var(--theme_g1);
- text-transform: uppercase;
- letter-spacing: 0.5px;
- margin-bottom: 6px;
- }
-
- .detail-table {
- width: 100%;
- border-collapse: collapse;
- font-size: 12px;
- }
-
- .detail-table td {
- padding: 4px 8px;
- color: var(--theme_g0);
- border-bottom: 1px solid var(--theme_border_subtle);
- vertical-align: top;
- }
-
- .detail-table td:first-child {
- color: var(--theme_g1);
- width: 40%;
- font-family: monospace;
- }
-
- .detail-table tr:last-child td {
- border-bottom: none;
- }
-
- .detail-mono {
- font-family: monospace;
- font-size: 11px;
- color: var(--theme_g1);
- }
-
- .detail-tag {
- display: inline-block;
- padding: 2px 8px;
- border-radius: 4px;
- background: var(--theme_border_subtle);
- color: var(--theme_g0);
- font-size: 11px;
- margin: 2px 4px 2px 0;
- }
- </style>
-</head>
-<body>
- <div class="container" style="max-width: 1400px; margin: 0 auto;">
- <zen-banner cluster-status="nominal" load="0" tagline="Node Overview" logo-src="../favicon.ico"></zen-banner>
- <zen-nav>
- <a href="/dashboard/">Home</a>
- <a href="compute.html">Node</a>
- <a href="orchestrator.html">Orchestrator</a>
- </zen-nav>
- <div class="timestamp">Last updated: <span id="last-update">Never</span></div>
-
- <div id="error-container"></div>
-
- <!-- Action Queue Stats -->
- <div class="section-title">Action Queue</div>
- <div class="grid">
- <div class="card">
- <div class="card-title">Pending Actions</div>
- <div class="metric-value" id="actions-pending">-</div>
- <div class="metric-label">Waiting to be scheduled</div>
- </div>
- <div class="card">
- <div class="card-title">Running Actions</div>
- <div class="metric-value" id="actions-running">-</div>
- <div class="metric-label">Currently executing</div>
- </div>
- <div class="card">
- <div class="card-title">Completed Actions</div>
- <div class="metric-value" id="actions-complete">-</div>
- <div class="metric-label">Results available</div>
- </div>
- </div>
-
- <!-- Action Queue Chart -->
- <div class="card" style="margin-bottom: 30px;">
- <div class="card-title">Action Queue History</div>
- <div class="chart-container">
- <canvas id="queue-chart"></canvas>
- </div>
- </div>
-
- <!-- Performance Metrics -->
- <div class="section-title">Performance Metrics</div>
- <div class="card" style="margin-bottom: 30px;">
- <div class="card-title">Completion Rate</div>
- <div class="rate-stats">
- <div class="rate-item">
- <div class="rate-value" id="rate-1">-</div>
- <div class="rate-label">1 min rate</div>
- </div>
- <div class="rate-item">
- <div class="rate-value" id="rate-5">-</div>
- <div class="rate-label">5 min rate</div>
- </div>
- <div class="rate-item">
- <div class="rate-value" id="rate-15">-</div>
- <div class="rate-label">15 min rate</div>
- </div>
- </div>
- <div style="margin-top: 20px;">
- <div class="stats-row">
- <span class="stats-label">Total Retired</span>
- <span class="stats-value" id="retired-count">-</span>
- </div>
- <div class="stats-row">
- <span class="stats-label">Mean Rate</span>
- <span class="stats-value" id="rate-mean">-</span>
- </div>
- </div>
- </div>
-
- <!-- Workers -->
- <div class="section-title">Workers</div>
- <div class="card" style="margin-bottom: 30px;">
- <div class="card-title">Worker Status</div>
- <div class="stats-row">
- <span class="stats-label">Registered Workers</span>
- <span class="stats-value" id="worker-count">-</span>
- </div>
- <div id="worker-table-container" style="margin-top: 16px; display: none;">
- <table id="worker-table">
- <thead>
- <tr>
- <th>Name</th>
- <th>Platform</th>
- <th style="text-align: right;">Cores</th>
- <th style="text-align: right;">Timeout</th>
- <th style="text-align: right;">Functions</th>
- <th>Worker ID</th>
- </tr>
- </thead>
- <tbody id="worker-table-body"></tbody>
- </table>
- <div id="worker-detail" class="worker-detail" style="display: none;"></div>
- </div>
- </div>
-
- <!-- Queues -->
- <div class="section-title">Queues</div>
- <div class="card" style="margin-bottom: 30px;">
- <div class="card-title">Queue Status</div>
- <div id="queue-list-empty" class="empty-state" style="text-align: left;">No queues.</div>
- <div id="queue-list-container" style="display: none;">
- <table id="queue-list-table">
- <thead>
- <tr>
- <th style="text-align: right; width: 60px;">ID</th>
- <th style="text-align: center; width: 80px;">Status</th>
- <th style="text-align: right;">Active</th>
- <th style="text-align: right;">Completed</th>
- <th style="text-align: right;">Failed</th>
- <th style="text-align: right;">Abandoned</th>
- <th style="text-align: right;">Cancelled</th>
- <th>Token</th>
- </tr>
- </thead>
- <tbody id="queue-list-body"></tbody>
- </table>
- </div>
- </div>
-
- <!-- Action History -->
- <div class="section-title">Recent Actions</div>
- <div class="card" style="margin-bottom: 30px;">
- <div class="card-title">Action History</div>
- <div id="action-history-empty" class="empty-state" style="text-align: left;">No actions recorded yet.</div>
- <div id="action-history-container" style="display: none;">
- <table id="action-history-table">
- <thead>
- <tr>
- <th style="text-align: right; width: 60px;">LSN</th>
- <th style="text-align: right; width: 60px;">Queue</th>
- <th style="text-align: center; width: 70px;">Status</th>
- <th>Function</th>
- <th style="text-align: right; width: 80px;">Started</th>
- <th style="text-align: right; width: 80px;">Finished</th>
- <th style="text-align: right; width: 80px;">Duration</th>
- <th>Worker ID</th>
- <th>Action ID</th>
- </tr>
- </thead>
- <tbody id="action-history-body"></tbody>
- </table>
- </div>
- </div>
-
- <!-- System Resources -->
- <div class="section-title">System Resources</div>
- <div class="grid">
- <div class="card">
- <div class="card-title">CPU Usage</div>
- <div class="metric-value" id="cpu-usage">-</div>
- <div class="metric-label">Percent</div>
- <div class="progress-bar">
- <div class="progress-fill" id="cpu-progress" style="width: 0%"></div>
- </div>
- <div style="position: relative; height: 60px; margin-top: 12px;">
- <canvas id="cpu-chart"></canvas>
- </div>
- <div style="margin-top: 12px;">
- <div class="stats-row">
- <span class="stats-label">Packages</span>
- <span class="stats-value" id="cpu-packages">-</span>
- </div>
- <div class="stats-row">
- <span class="stats-label">Physical Cores</span>
- <span class="stats-value" id="cpu-cores">-</span>
- </div>
- <div class="stats-row">
- <span class="stats-label">Logical Processors</span>
- <span class="stats-value" id="cpu-lp">-</span>
- </div>
- </div>
- </div>
- <div class="card">
- <div class="card-title">Memory</div>
- <div class="stats-row">
- <span class="stats-label">Used</span>
- <span class="stats-value" id="memory-used">-</span>
- </div>
- <div class="stats-row">
- <span class="stats-label">Total</span>
- <span class="stats-value" id="memory-total">-</span>
- </div>
- <div class="progress-bar">
- <div class="progress-fill" id="memory-progress" style="width: 0%"></div>
- </div>
- </div>
- <div class="card">
- <div class="card-title">Disk</div>
- <div class="stats-row">
- <span class="stats-label">Used</span>
- <span class="stats-value" id="disk-used">-</span>
- </div>
- <div class="stats-row">
- <span class="stats-label">Total</span>
- <span class="stats-value" id="disk-total">-</span>
- </div>
- <div class="progress-bar">
- <div class="progress-fill" id="disk-progress" style="width: 0%"></div>
- </div>
- </div>
- </div>
- </div>
-
- <script>
- // Configuration
- const BASE_URL = window.location.origin;
- const REFRESH_INTERVAL = 2000; // 2 seconds
- const MAX_HISTORY_POINTS = 60; // Show last 2 minutes
-
- // Data storage
- const history = {
- timestamps: [],
- pending: [],
- running: [],
- completed: [],
- cpu: []
- };
-
- // CPU sparkline chart
- const cpuCtx = document.getElementById('cpu-chart').getContext('2d');
- const cpuChart = new Chart(cpuCtx, {
- type: 'line',
- data: {
- labels: [],
- datasets: [{
- data: [],
- borderColor: '#58a6ff',
- backgroundColor: 'rgba(88, 166, 255, 0.15)',
- borderWidth: 1.5,
- tension: 0.4,
- fill: true,
- pointRadius: 0
- }]
- },
- options: {
- responsive: true,
- maintainAspectRatio: false,
- animation: false,
- plugins: { legend: { display: false }, tooltip: { enabled: false } },
- scales: {
- x: { display: false },
- y: { display: false, min: 0, max: 100 }
- }
- }
- });
-
- // Queue chart setup
- const ctx = document.getElementById('queue-chart').getContext('2d');
- const chart = new Chart(ctx, {
- type: 'line',
- data: {
- labels: [],
- datasets: [
- {
- label: 'Pending',
- data: [],
- borderColor: '#f0883e',
- backgroundColor: 'rgba(240, 136, 62, 0.1)',
- tension: 0.4,
- fill: true
- },
- {
- label: 'Running',
- data: [],
- borderColor: '#58a6ff',
- backgroundColor: 'rgba(88, 166, 255, 0.1)',
- tension: 0.4,
- fill: true
- },
- {
- label: 'Completed',
- data: [],
- borderColor: '#3fb950',
- backgroundColor: 'rgba(63, 185, 80, 0.1)',
- tension: 0.4,
- fill: true
- }
- ]
- },
- options: {
- responsive: true,
- maintainAspectRatio: false,
- plugins: {
- legend: {
- display: true,
- labels: {
- color: '#8b949e'
- }
- }
- },
- scales: {
- x: {
- display: false
- },
- y: {
- beginAtZero: true,
- ticks: {
- color: '#8b949e'
- },
- grid: {
- color: '#21262d'
- }
- }
- }
- }
- });
-
- // Helper functions
-
- function formatBytes(bytes) {
- if (bytes === 0) return '0 B';
- const k = 1024;
- const sizes = ['B', 'KB', 'MB', 'GB', 'TB'];
- const i = Math.floor(Math.log(bytes) / Math.log(k));
- return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
- }
-
- function formatRate(rate) {
- return rate.toFixed(2) + '/s';
- }
-
- function showError(message) {
- const container = document.getElementById('error-container');
- container.innerHTML = `<div class="error">Error: ${escapeHtml(message)}</div>`;
- }
-
- function clearError() {
- document.getElementById('error-container').innerHTML = '';
- }
-
- function updateTimestamp() {
- const now = new Date();
- document.getElementById('last-update').textContent = now.toLocaleTimeString();
- }
-
- // Fetch functions
- async function fetchJSON(endpoint) {
- const response = await fetch(`${BASE_URL}${endpoint}`, {
- headers: {
- 'Accept': 'application/json'
- }
- });
- if (!response.ok) {
- throw new Error(`HTTP ${response.status}: ${response.statusText}`);
- }
- return await response.json();
- }
-
- async function fetchHealth() {
- try {
- const response = await fetch(`${BASE_URL}/compute/ready`);
- const isHealthy = response.status === 200;
-
- const banner = document.querySelector('zen-banner');
-
- if (isHealthy) {
- banner.setAttribute('cluster-status', 'nominal');
- banner.setAttribute('load', '0');
- } else {
- banner.setAttribute('cluster-status', 'degraded');
- banner.setAttribute('load', '0');
- }
-
- return isHealthy;
- } catch (error) {
- const banner = document.querySelector('zen-banner');
- banner.setAttribute('cluster-status', 'degraded');
- banner.setAttribute('load', '0');
- throw error;
- }
- }
-
- async function fetchStats() {
- const data = await fetchJSON('/stats/compute');
-
- // Update action counts
- document.getElementById('actions-pending').textContent = data.actions_pending || 0;
- document.getElementById('actions-running').textContent = data.actions_submitted || 0;
- document.getElementById('actions-complete').textContent = data.actions_complete || 0;
-
- // Update completion rates
- if (data.actions_retired) {
- document.getElementById('rate-1').textContent = formatRate(data.actions_retired.rate_1 || 0);
- document.getElementById('rate-5').textContent = formatRate(data.actions_retired.rate_5 || 0);
- document.getElementById('rate-15').textContent = formatRate(data.actions_retired.rate_15 || 0);
- document.getElementById('retired-count').textContent = data.actions_retired.count || 0;
- document.getElementById('rate-mean').textContent = formatRate(data.actions_retired.rate_mean || 0);
- }
-
- // Update chart
- const now = new Date().toLocaleTimeString();
- history.timestamps.push(now);
- history.pending.push(data.actions_pending || 0);
- history.running.push(data.actions_submitted || 0);
- history.completed.push(data.actions_complete || 0);
-
- // Keep only last N points
- if (history.timestamps.length > MAX_HISTORY_POINTS) {
- history.timestamps.shift();
- history.pending.shift();
- history.running.shift();
- history.completed.shift();
- }
-
- chart.data.labels = history.timestamps;
- chart.data.datasets[0].data = history.pending;
- chart.data.datasets[1].data = history.running;
- chart.data.datasets[2].data = history.completed;
- chart.update('none');
- }
-
- async function fetchSysInfo() {
- const data = await fetchJSON('/compute/sysinfo');
-
- // Update CPU
- const cpuUsage = data.cpu_usage || 0;
- document.getElementById('cpu-usage').textContent = cpuUsage.toFixed(1) + '%';
- document.getElementById('cpu-progress').style.width = cpuUsage + '%';
-
- const banner = document.querySelector('zen-banner');
- banner.setAttribute('load', cpuUsage.toFixed(1));
-
- history.cpu.push(cpuUsage);
- if (history.cpu.length > MAX_HISTORY_POINTS) history.cpu.shift();
- cpuChart.data.labels = history.cpu.map(() => '');
- cpuChart.data.datasets[0].data = history.cpu;
- cpuChart.update('none');
-
- document.getElementById('cpu-packages').textContent = data.cpu_count ?? '-';
- document.getElementById('cpu-cores').textContent = data.core_count ?? '-';
- document.getElementById('cpu-lp').textContent = data.lp_count ?? '-';
-
- // Update Memory
- const memUsed = data.memory_used || 0;
- const memTotal = data.memory_total || 1;
- const memPercent = (memUsed / memTotal) * 100;
- document.getElementById('memory-used').textContent = formatBytes(memUsed);
- document.getElementById('memory-total').textContent = formatBytes(memTotal);
- document.getElementById('memory-progress').style.width = memPercent + '%';
-
- // Update Disk
- const diskUsed = data.disk_used || 0;
- const diskTotal = data.disk_total || 1;
- const diskPercent = (diskUsed / diskTotal) * 100;
- document.getElementById('disk-used').textContent = formatBytes(diskUsed);
- document.getElementById('disk-total').textContent = formatBytes(diskTotal);
- document.getElementById('disk-progress').style.width = diskPercent + '%';
- }
-
- // Persists the selected worker ID across refreshes
- let selectedWorkerId = null;
-
- function renderWorkerDetail(id, desc) {
- const panel = document.getElementById('worker-detail');
-
- if (!desc) {
- panel.style.display = 'none';
- return;
- }
-
- function field(label, value) {
- return `<tr><td>${label}</td><td>${value ?? '-'}</td></tr>`;
- }
-
- function monoField(label, value) {
- return `<tr><td>${label}</td><td class="detail-mono">${value ?? '-'}</td></tr>`;
- }
-
- // Functions
- const functions = desc.functions || [];
- const functionsHtml = functions.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
- `<table class="detail-table">${functions.map(f =>
- `<tr><td>${escapeHtml(f.name || '-')}</td><td class="detail-mono">${escapeHtml(f.version || '-')}</td></tr>`
- ).join('')}</table>`;
-
- // Executables
- const executables = desc.executables || [];
- const totalExecSize = executables.reduce((sum, e) => sum + (e.size || 0), 0);
- const execHtml = executables.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
- `<table class="detail-table">
- <tr style="font-size:11px;">
- <td style="color:var(--theme_faint);padding-bottom:4px;">Path</td>
- <td style="color:var(--theme_faint);padding-bottom:4px;">Hash</td>
- <td style="color:var(--theme_faint);padding-bottom:4px;text-align:right;">Size</td>
- </tr>
- ${executables.map(e =>
- `<tr>
- <td>${escapeHtml(e.name || '-')}</td>
- <td class="detail-mono">${escapeHtml(e.hash || '-')}</td>
- <td style="text-align:right;white-space:nowrap;">${e.size != null ? formatBytes(e.size) : '-'}</td>
- </tr>`
- ).join('')}
- <tr style="border-top:1px solid var(--theme_g2);">
- <td style="color:var(--theme_g1);padding-top:6px;">Total</td>
- <td></td>
- <td style="text-align:right;white-space:nowrap;padding-top:6px;color:var(--theme_bright);font-weight:600;">${formatBytes(totalExecSize)}</td>
- </tr>
- </table>`;
-
- // Files
- const files = desc.files || [];
- const filesHtml = files.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
- `<table class="detail-table">${files.map(f =>
- `<tr><td>${escapeHtml(f.name || f)}</td><td class="detail-mono">${escapeHtml(f.hash || '')}</td></tr>`
- ).join('')}</table>`;
-
- // Dirs
- const dirs = desc.dirs || [];
- const dirsHtml = dirs.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
- dirs.map(d => `<span class="detail-tag">${escapeHtml(d)}</span>`).join('');
-
- // Environment
- const env = desc.environment || [];
- const envHtml = env.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
- env.map(e => `<span class="detail-tag">${escapeHtml(e)}</span>`).join('');
-
- panel.innerHTML = `
- <div class="worker-detail-title">${escapeHtml(desc.name || id)}</div>
- <div class="detail-section">
- <table class="detail-table">
- ${field('Worker ID', `<span class="detail-mono">${escapeHtml(id)}</span>`)}
- ${field('Path', escapeHtml(desc.path || '-'))}
- ${field('Platform', escapeHtml(desc.host || '-'))}
- ${monoField('Build System', desc.buildsystem_version)}
- ${field('Cores', desc.cores)}
- ${field('Timeout', desc.timeout != null ? desc.timeout + 's' : null)}
- </table>
- </div>
- <div class="detail-section">
- <div class="detail-section-label">Functions</div>
- ${functionsHtml}
- </div>
- <div class="detail-section">
- <div class="detail-section-label">Executables</div>
- ${execHtml}
- </div>
- <div class="detail-section">
- <div class="detail-section-label">Files</div>
- ${filesHtml}
- </div>
- <div class="detail-section">
- <div class="detail-section-label">Directories</div>
- ${dirsHtml}
- </div>
- <div class="detail-section">
- <div class="detail-section-label">Environment</div>
- ${envHtml}
- </div>
- `;
- panel.style.display = 'block';
- }
-
- async function fetchWorkers() {
- const data = await fetchJSON('/compute/workers');
- const workerIds = data.workers || [];
-
- document.getElementById('worker-count').textContent = workerIds.length;
-
- const container = document.getElementById('worker-table-container');
- const tbody = document.getElementById('worker-table-body');
-
- if (workerIds.length === 0) {
- container.style.display = 'none';
- selectedWorkerId = null;
- return;
- }
-
- const descriptors = await Promise.all(
- workerIds.map(id => fetchJSON(`/compute/workers/${id}`).catch(() => null))
- );
-
- // Build a map for quick lookup by ID
- const descriptorMap = {};
- workerIds.forEach((id, i) => { descriptorMap[id] = descriptors[i]; });
-
- tbody.innerHTML = '';
- descriptors.forEach((desc, i) => {
- const id = workerIds[i];
- const name = desc ? (desc.name || '-') : '-';
- const host = desc ? (desc.host || '-') : '-';
- const cores = desc ? (desc.cores != null ? desc.cores : '-') : '-';
- const timeout = desc ? (desc.timeout != null ? desc.timeout + 's' : '-') : '-';
- const functions = desc ? (desc.functions ? desc.functions.length : 0) : '-';
-
- const tr = document.createElement('tr');
- tr.className = 'worker-row' + (id === selectedWorkerId ? ' selected' : '');
- tr.dataset.workerId = id;
- tr.innerHTML = `
- <td style="color: var(--theme_bright);">${escapeHtml(name)}</td>
- <td>${escapeHtml(host)}</td>
- <td style="text-align: right;">${escapeHtml(String(cores))}</td>
- <td style="text-align: right;">${escapeHtml(String(timeout))}</td>
- <td style="text-align: right;">${escapeHtml(String(functions))}</td>
- <td style="color: var(--theme_g1); font-family: monospace; font-size: 11px;">${escapeHtml(id)}</td>
- `;
- tr.addEventListener('click', () => {
- document.querySelectorAll('.worker-row').forEach(r => r.classList.remove('selected'));
- if (selectedWorkerId === id) {
- // Toggle off
- selectedWorkerId = null;
- document.getElementById('worker-detail').style.display = 'none';
- } else {
- selectedWorkerId = id;
- tr.classList.add('selected');
- renderWorkerDetail(id, descriptorMap[id]);
- }
- });
- tbody.appendChild(tr);
- });
-
- // Re-render detail if selected worker is still present
- if (selectedWorkerId && descriptorMap[selectedWorkerId]) {
- renderWorkerDetail(selectedWorkerId, descriptorMap[selectedWorkerId]);
- } else if (selectedWorkerId && !descriptorMap[selectedWorkerId]) {
- selectedWorkerId = null;
- document.getElementById('worker-detail').style.display = 'none';
- }
-
- container.style.display = 'block';
- }
-
- // Windows FILETIME: 100ns ticks since 1601-01-01. Convert to JS Date.
- const FILETIME_EPOCH_OFFSET_MS = 11644473600000n;
- function filetimeToDate(ticks) {
- if (!ticks) return null;
- const ms = BigInt(ticks) / 10000n - FILETIME_EPOCH_OFFSET_MS;
- return new Date(Number(ms));
- }
-
- function formatTime(date) {
- if (!date) return '-';
- return date.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' });
- }
-
- function formatDuration(startDate, endDate) {
- if (!startDate || !endDate) return '-';
- const ms = endDate - startDate;
- if (ms < 0) return '-';
- if (ms < 1000) return ms + ' ms';
- if (ms < 60000) return (ms / 1000).toFixed(2) + ' s';
- const m = Math.floor(ms / 60000);
- const s = ((ms % 60000) / 1000).toFixed(0).padStart(2, '0');
- return `${m}m ${s}s`;
- }
-
- async function fetchQueues() {
- const data = await fetchJSON('/compute/queues');
- const queues = data.queues || [];
-
- const empty = document.getElementById('queue-list-empty');
- const container = document.getElementById('queue-list-container');
- const tbody = document.getElementById('queue-list-body');
-
- if (queues.length === 0) {
- empty.style.display = '';
- container.style.display = 'none';
- return;
- }
-
- empty.style.display = 'none';
- tbody.innerHTML = '';
-
- for (const q of queues) {
- const id = q.queue_id ?? '-';
- const badge = q.state === 'cancelled'
- ? '<span class="status-badge failure">cancelled</span>'
- : q.state === 'draining'
- ? '<span class="status-badge" style="background:color-mix(in srgb, var(--theme_warn) 15%, transparent);color:var(--theme_warn);">draining</span>'
- : q.is_complete
- ? '<span class="status-badge success">complete</span>'
- : '<span class="status-badge" style="background:color-mix(in srgb, var(--theme_p0) 15%, transparent);color:var(--theme_p0);">active</span>';
- const token = q.queue_token
- ? `<span class="detail-mono">${escapeHtml(q.queue_token)}</span>`
- : '<span style="color:var(--theme_faint);">-</span>';
-
- const tr = document.createElement('tr');
- tr.innerHTML = `
- <td style="text-align: right; font-family: monospace; color: var(--theme_bright);">${escapeHtml(String(id))}</td>
- <td style="text-align: center;">${badge}</td>
- <td style="text-align: right;">${q.active_count ?? 0}</td>
- <td style="text-align: right; color: var(--theme_ok);">${q.completed_count ?? 0}</td>
- <td style="text-align: right; color: var(--theme_fail);">${q.failed_count ?? 0}</td>
- <td style="text-align: right; color: var(--theme_warn);">${q.abandoned_count ?? 0}</td>
- <td style="text-align: right; color: var(--theme_warn);">${q.cancelled_count ?? 0}</td>
- <td>${token}</td>
- `;
- tbody.appendChild(tr);
- }
-
- container.style.display = 'block';
- }
-
- async function fetchActionHistory() {
- const data = await fetchJSON('/compute/jobs/history?limit=50');
- const entries = data.history || [];
-
- const empty = document.getElementById('action-history-empty');
- const container = document.getElementById('action-history-container');
- const tbody = document.getElementById('action-history-body');
-
- if (entries.length === 0) {
- empty.style.display = '';
- container.style.display = 'none';
- return;
- }
-
- empty.style.display = 'none';
- tbody.innerHTML = '';
-
- // Entries arrive oldest-first; reverse to show newest at top
- for (const entry of [...entries].reverse()) {
- const lsn = entry.lsn ?? '-';
- const succeeded = entry.succeeded;
- const badge = succeeded == null
- ? '<span class="status-badge" style="background:var(--theme_border_subtle);color:var(--theme_g1);">unknown</span>'
- : succeeded
- ? '<span class="status-badge success">ok</span>'
- : '<span class="status-badge failure">failed</span>';
- const desc = entry.actionDescriptor || {};
- const fn = desc.Function || '-';
- const workerId = entry.workerId || '-';
- const actionId = entry.actionId || '-';
-
- const startDate = filetimeToDate(entry.time_Running);
- const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed);
-
- const queueId = entry.queueId || 0;
- const queueCell = queueId
- ? `<a href="/compute/queues/${queueId}" style="color: var(--theme_ln); text-decoration: none; font-family: monospace;">${escapeHtml(String(queueId))}</a>`
- : '<span style="color: var(--theme_faint);">-</span>';
-
- const tr = document.createElement('tr');
- tr.innerHTML = `
- <td style="text-align: right; font-family: monospace; color: var(--theme_g1);">${escapeHtml(String(lsn))}</td>
- <td style="text-align: right;">${queueCell}</td>
- <td style="text-align: center;">${badge}</td>
- <td style="color: var(--theme_bright);">${escapeHtml(fn)}</td>
- <td style="text-align: right; font-size: 12px; white-space: nowrap; color: var(--theme_g1);">${formatTime(startDate)}</td>
- <td style="text-align: right; font-size: 12px; white-space: nowrap; color: var(--theme_g1);">${formatTime(endDate)}</td>
- <td style="text-align: right; font-size: 12px; white-space: nowrap;">${formatDuration(startDate, endDate)}</td>
- <td style="font-family: monospace; font-size: 11px; color: var(--theme_g1);">${escapeHtml(workerId)}</td>
- <td style="font-family: monospace; font-size: 11px; color: var(--theme_g1);">${escapeHtml(actionId)}</td>
- `;
- tbody.appendChild(tr);
- }
-
- container.style.display = 'block';
- }
-
- async function updateDashboard() {
- try {
- await Promise.all([
- fetchHealth(),
- fetchStats(),
- fetchSysInfo(),
- fetchWorkers(),
- fetchQueues(),
- fetchActionHistory()
- ]);
-
- clearError();
- updateTimestamp();
- } catch (error) {
- console.error('Error updating dashboard:', error);
- showError(error.message);
- }
- }
-
- // Start updating
- updateDashboard();
- setInterval(updateDashboard, REFRESH_INTERVAL);
- </script>
-</body>
-</html>
diff --git a/src/zenserver/frontend/html/compute/hub.html b/src/zenserver/frontend/html/compute/hub.html
index b15b34577..41c80d3a3 100644
--- a/src/zenserver/frontend/html/compute/hub.html
+++ b/src/zenserver/frontend/html/compute/hub.html
@@ -83,7 +83,7 @@
}
async function fetchStats() {
- var data = await fetchJSON('/hub/stats');
+ var data = await fetchJSON('/stats/hub');
var current = data.currentInstanceCount || 0;
var max = data.maxInstanceCount || 0;
diff --git a/src/zenserver/frontend/html/compute/index.html b/src/zenserver/frontend/html/compute/index.html
index 9597fd7f3..aaa09aec0 100644
--- a/src/zenserver/frontend/html/compute/index.html
+++ b/src/zenserver/frontend/html/compute/index.html
@@ -1 +1 @@
-<meta http-equiv="refresh" content="0; url=compute.html" /> \ No newline at end of file
+<meta http-equiv="refresh" content="0; url=/dashboard/?page=compute" /> \ No newline at end of file
diff --git a/src/zenserver/frontend/html/compute/orchestrator.html b/src/zenserver/frontend/html/compute/orchestrator.html
deleted file mode 100644
index d1a2bb015..000000000
--- a/src/zenserver/frontend/html/compute/orchestrator.html
+++ /dev/null
@@ -1,669 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-<head>
- <meta charset="UTF-8">
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
- <link rel="stylesheet" type="text/css" href="../zen.css" />
- <script src="../util/sanitize.js"></script>
- <script src="../theme.js"></script>
- <script src="../banner.js" defer></script>
- <script src="../nav.js" defer></script>
- <title>Zen Orchestrator Dashboard</title>
- <style>
- .agent-count {
- display: flex;
- align-items: center;
- gap: 8px;
- font-size: 14px;
- padding: 8px 16px;
- border-radius: 6px;
- background: var(--theme_g3);
- border: 1px solid var(--theme_g2);
- }
-
- .agent-count .count {
- font-size: 20px;
- font-weight: 600;
- color: var(--theme_bright);
- }
- </style>
-</head>
-<body>
- <div class="container" style="max-width: 1400px; margin: 0 auto;">
- <zen-banner cluster-status="nominal" load="0" logo-src="../favicon.ico"></zen-banner>
- <zen-nav>
- <a href="/dashboard/">Home</a>
- <a href="compute.html">Node</a>
- <a href="orchestrator.html">Orchestrator</a>
- </zen-nav>
- <div class="header">
- <div>
- <div class="timestamp">Last updated: <span id="last-update">Never</span></div>
- </div>
- <div class="agent-count">
- <span>Agents:</span>
- <span class="count" id="agent-count">-</span>
- </div>
- </div>
-
- <div id="error-container"></div>
-
- <div class="card">
- <div class="card-title">Compute Agents</div>
- <div id="empty-state" class="empty-state">No agents registered.</div>
- <table id="agent-table" style="display: none;">
- <thead>
- <tr>
- <th style="width: 40px; text-align: center;">Health</th>
- <th>Hostname</th>
- <th style="text-align: right;">CPUs</th>
- <th style="text-align: right;">CPU Usage</th>
- <th style="text-align: right;">Memory</th>
- <th style="text-align: right;">Queues</th>
- <th style="text-align: right;">Pending</th>
- <th style="text-align: right;">Running</th>
- <th style="text-align: right;">Completed</th>
- <th style="text-align: right;">Traffic</th>
- <th style="text-align: right;">Last Seen</th>
- </tr>
- </thead>
- <tbody id="agent-table-body"></tbody>
- </table>
- </div>
- <div class="card" style="margin-top: 20px;">
- <div class="card-title">Connected Clients</div>
- <div id="clients-empty" class="empty-state">No clients connected.</div>
- <table id="clients-table" style="display: none;">
- <thead>
- <tr>
- <th style="width: 40px; text-align: center;">Health</th>
- <th>Client ID</th>
- <th>Hostname</th>
- <th>Address</th>
- <th style="text-align: right;">Last Seen</th>
- </tr>
- </thead>
- <tbody id="clients-table-body"></tbody>
- </table>
- </div>
- <div class="card" style="margin-top: 20px;">
- <div style="display: flex; align-items: center; gap: 12px; margin-bottom: 12px;">
- <div class="card-title" style="margin-bottom: 0;">Event History</div>
- <div class="history-tabs">
- <button class="history-tab active" data-tab="workers" onclick="switchHistoryTab('workers')">Workers</button>
- <button class="history-tab" data-tab="clients" onclick="switchHistoryTab('clients')">Clients</button>
- </div>
- </div>
- <div id="history-panel-workers">
- <div id="history-empty" class="empty-state">No provisioning events recorded.</div>
- <table id="history-table" style="display: none;">
- <thead>
- <tr>
- <th>Time</th>
- <th>Event</th>
- <th>Worker</th>
- <th>Hostname</th>
- </tr>
- </thead>
- <tbody id="history-table-body"></tbody>
- </table>
- </div>
- <div id="history-panel-clients" style="display: none;">
- <div id="client-history-empty" class="empty-state">No client events recorded.</div>
- <table id="client-history-table" style="display: none;">
- <thead>
- <tr>
- <th>Time</th>
- <th>Event</th>
- <th>Client</th>
- <th>Hostname</th>
- </tr>
- </thead>
- <tbody id="client-history-table-body"></tbody>
- </table>
- </div>
- </div>
- </div>
-
- <script>
- const BASE_URL = window.location.origin;
- const REFRESH_INTERVAL = 2000;
-
- function showError(message) {
- document.getElementById('error-container').innerHTML =
- '<div class="error">Error: ' + escapeHtml(message) + '</div>';
- }
-
- function clearError() {
- document.getElementById('error-container').innerHTML = '';
- }
-
- function formatLastSeen(dtMs) {
- if (dtMs == null) return '-';
- var seconds = Math.floor(dtMs / 1000);
- if (seconds < 60) return seconds + 's ago';
- var minutes = Math.floor(seconds / 60);
- if (minutes < 60) return minutes + 'm ' + (seconds % 60) + 's ago';
- var hours = Math.floor(minutes / 60);
- return hours + 'h ' + (minutes % 60) + 'm ago';
- }
-
- function healthClass(dtMs, reachable) {
- if (reachable === false) return 'health-red';
- if (dtMs == null) return 'health-red';
- var seconds = dtMs / 1000;
- if (seconds < 30 && reachable === true) return 'health-green';
- if (seconds < 120) return 'health-yellow';
- return 'health-red';
- }
-
- function healthTitle(dtMs, reachable) {
- var seenStr = dtMs != null ? 'Last seen ' + formatLastSeen(dtMs) : 'Never seen';
- if (reachable === true) return seenStr + ' · Reachable';
- if (reachable === false) return seenStr + ' · Unreachable';
- return seenStr + ' · Reachability unknown';
- }
-
- function formatCpuUsage(percent) {
- if (percent == null || percent === 0) return '-';
- return percent.toFixed(1) + '%';
- }
-
- function formatMemory(usedBytes, totalBytes) {
- if (!totalBytes) return '-';
- var usedGiB = usedBytes / (1024 * 1024 * 1024);
- var totalGiB = totalBytes / (1024 * 1024 * 1024);
- return usedGiB.toFixed(1) + ' / ' + totalGiB.toFixed(1) + ' GiB';
- }
-
- function formatBytes(bytes) {
- if (!bytes) return '-';
- if (bytes < 1024) return bytes + ' B';
- if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KiB';
- if (bytes < 1024 * 1024 * 1024) return (bytes / (1024 * 1024)).toFixed(1) + ' MiB';
- if (bytes < 1024 * 1024 * 1024 * 1024) return (bytes / (1024 * 1024 * 1024)).toFixed(1) + ' GiB';
- return (bytes / (1024 * 1024 * 1024 * 1024)).toFixed(1) + ' TiB';
- }
-
- function formatTraffic(recv, sent) {
- if (!recv && !sent) return '-';
- return formatBytes(recv) + ' / ' + formatBytes(sent);
- }
-
- function parseIpFromUri(uri) {
- try {
- var url = new URL(uri);
- var host = url.hostname;
- // Strip IPv6 brackets
- if (host.startsWith('[') && host.endsWith(']')) host = host.slice(1, -1);
- // Only handle IPv4
- var parts = host.split('.');
- if (parts.length !== 4) return null;
- var octets = parts.map(Number);
- if (octets.some(function(o) { return isNaN(o) || o < 0 || o > 255; })) return null;
- return octets;
- } catch (e) {
- return null;
- }
- }
-
- function computeCidr(ips) {
- if (ips.length === 0) return null;
- if (ips.length === 1) return ips[0].join('.') + '/32';
-
- // Convert each IP to a 32-bit integer
- var ints = ips.map(function(o) {
- return ((o[0] << 24) | (o[1] << 16) | (o[2] << 8) | o[3]) >>> 0;
- });
-
- // Find common prefix length by ANDing all identical high bits
- var common = ~0 >>> 0;
- for (var i = 1; i < ints.length; i++) {
- // XOR to find differing bits, then mask away everything from the first difference down
- var diff = (ints[0] ^ ints[i]) >>> 0;
- if (diff !== 0) {
- var bit = 31 - Math.floor(Math.log2(diff));
- var mask = bit > 0 ? ((~0 << (32 - bit)) >>> 0) : 0;
- common = (common & mask) >>> 0;
- }
- }
-
- // Count leading ones in the common mask
- var prefix = 0;
- for (var b = 31; b >= 0; b--) {
- if ((common >>> b) & 1) prefix++;
- else break;
- }
-
- // Network address
- var net = (ints[0] & common) >>> 0;
- var a = (net >>> 24) & 0xff;
- var bv = (net >>> 16) & 0xff;
- var c = (net >>> 8) & 0xff;
- var d = net & 0xff;
- return a + '.' + bv + '.' + c + '.' + d + '/' + prefix;
- }
-
- function renderDashboard(data) {
- var banner = document.querySelector('zen-banner');
- if (data.hostname) {
- banner.setAttribute('tagline', 'Orchestrator \u2014 ' + data.hostname);
- }
- var workers = data.workers || [];
-
- document.getElementById('agent-count').textContent = workers.length;
-
- if (workers.length === 0) {
- banner.setAttribute('cluster-status', 'degraded');
- banner.setAttribute('load', '0');
- } else {
- banner.setAttribute('cluster-status', 'nominal');
- }
-
- var emptyState = document.getElementById('empty-state');
- var table = document.getElementById('agent-table');
- var tbody = document.getElementById('agent-table-body');
-
- if (workers.length === 0) {
- emptyState.style.display = '';
- table.style.display = 'none';
- } else {
- emptyState.style.display = 'none';
- table.style.display = '';
-
- tbody.innerHTML = '';
- var totalCpus = 0;
- var totalWeightedCpuUsage = 0;
- var totalMemUsed = 0;
- var totalMemTotal = 0;
- var totalQueues = 0;
- var totalPending = 0;
- var totalRunning = 0;
- var totalCompleted = 0;
- var totalBytesRecv = 0;
- var totalBytesSent = 0;
- var allIps = [];
- for (var i = 0; i < workers.length; i++) {
- var w = workers[i];
- var uri = w.uri || '';
- var dt = w.dt;
- var dashboardUrl = uri + '/dashboard/compute/';
-
- var id = w.id || '';
-
- var hostname = w.hostname || '';
- var cpus = w.cpus || 0;
- totalCpus += cpus;
- if (cpus > 0 && typeof w.cpu_usage === 'number') {
- totalWeightedCpuUsage += w.cpu_usage * cpus;
- }
-
- var memTotal = w.memory_total || 0;
- var memUsed = w.memory_used || 0;
- totalMemTotal += memTotal;
- totalMemUsed += memUsed;
-
- var activeQueues = w.active_queues || 0;
- totalQueues += activeQueues;
-
- var actionsPending = w.actions_pending || 0;
- var actionsRunning = w.actions_running || 0;
- var actionsCompleted = w.actions_completed || 0;
- totalPending += actionsPending;
- totalRunning += actionsRunning;
- totalCompleted += actionsCompleted;
-
- var bytesRecv = w.bytes_received || 0;
- var bytesSent = w.bytes_sent || 0;
- totalBytesRecv += bytesRecv;
- totalBytesSent += bytesSent;
-
- var ip = parseIpFromUri(uri);
- if (ip) allIps.push(ip);
-
- var reachable = w.reachable;
- var hClass = healthClass(dt, reachable);
- var hTitle = healthTitle(dt, reachable);
-
- var platform = w.platform || '';
- var badges = '';
- if (platform) {
- var platColors = { windows: '#0078d4', wine: '#722f37', linux: '#e95420', macos: '#a2aaad' };
- var platColor = platColors[platform] || '#8b949e';
- badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + platColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(platform) + '</span>';
- }
- var provisioner = w.provisioner || '';
- if (provisioner) {
- var provColors = { horde: '#8957e5', nomad: '#3fb950' };
- var provColor = provColors[provisioner] || '#8b949e';
- badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + provColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(provisioner) + '</span>';
- }
-
- var tr = document.createElement('tr');
- tr.title = id;
- tr.innerHTML =
- '<td style="text-align: center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' +
- '<td><a href="' + escapeHtml(dashboardUrl) + '" target="_blank">' + escapeHtml(hostname) + '</a>' + badges + '</td>' +
- '<td style="text-align: right;">' + (cpus > 0 ? cpus : '-') + '</td>' +
- '<td style="text-align: right;">' + formatCpuUsage(w.cpu_usage) + '</td>' +
- '<td style="text-align: right;">' + formatMemory(memUsed, memTotal) + '</td>' +
- '<td style="text-align: right;">' + (activeQueues > 0 ? activeQueues : '-') + '</td>' +
- '<td style="text-align: right;">' + actionsPending + '</td>' +
- '<td style="text-align: right;">' + actionsRunning + '</td>' +
- '<td style="text-align: right;">' + actionsCompleted + '</td>' +
- '<td style="text-align: right; font-size: 11px; color: var(--theme_g1);">' + formatTraffic(bytesRecv, bytesSent) + '</td>' +
- '<td style="text-align: right; color: var(--theme_g1);">' + formatLastSeen(dt) + '</td>';
- tbody.appendChild(tr);
- }
-
- var clusterLoad = totalCpus > 0 ? (totalWeightedCpuUsage / totalCpus) : 0;
- banner.setAttribute('load', clusterLoad.toFixed(1));
-
- // Total row
- var cidr = computeCidr(allIps);
- var totalTr = document.createElement('tr');
- totalTr.className = 'total-row';
- totalTr.innerHTML =
- '<td></td>' +
- '<td style="text-align: right; color: var(--theme_g1); text-transform: uppercase; font-size: 11px;">Total' + (cidr ? ' <span style="font-family: monospace; font-weight: normal;">' + escapeHtml(cidr) + '</span>' : '') + '</td>' +
- '<td style="text-align: right;">' + totalCpus + '</td>' +
- '<td></td>' +
- '<td style="text-align: right;">' + formatMemory(totalMemUsed, totalMemTotal) + '</td>' +
- '<td style="text-align: right;">' + totalQueues + '</td>' +
- '<td style="text-align: right;">' + totalPending + '</td>' +
- '<td style="text-align: right;">' + totalRunning + '</td>' +
- '<td style="text-align: right;">' + totalCompleted + '</td>' +
- '<td style="text-align: right; font-size: 11px;">' + formatTraffic(totalBytesRecv, totalBytesSent) + '</td>' +
- '<td></td>';
- tbody.appendChild(totalTr);
- }
-
- clearError();
- document.getElementById('last-update').textContent = new Date().toLocaleTimeString();
-
- // Render provisioning history if present in WebSocket payload
- if (data.events) {
- renderProvisioningHistory(data.events);
- }
-
- // Render connected clients if present
- if (data.clients) {
- renderClients(data.clients);
- }
-
- // Render client history if present
- if (data.client_events) {
- renderClientHistory(data.client_events);
- }
- }
-
- function eventBadge(type) {
- var colors = { joined: 'var(--theme_ok)', left: 'var(--theme_fail)', returned: 'var(--theme_warn)' };
- var labels = { joined: 'Joined', left: 'Left', returned: 'Returned' };
- var color = colors[type] || 'var(--theme_g1)';
- var label = labels[type] || type;
- return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:var(--theme_g4);background:' + color + ';">' + escapeHtml(label) + '</span>';
- }
-
- function formatTimestamp(ts) {
- if (!ts) return '-';
- // CbObject DateTime serialized as ticks (100ns since 0001-01-01) or ISO string
- var date;
- if (typeof ts === 'number') {
- // .NET-style ticks: convert to Unix ms
- var unixMs = (ts - 621355968000000000) / 10000;
- date = new Date(unixMs);
- } else {
- date = new Date(ts);
- }
- if (isNaN(date.getTime())) return '-';
- return date.toLocaleTimeString();
- }
-
- var activeHistoryTab = 'workers';
-
- function switchHistoryTab(tab) {
- activeHistoryTab = tab;
- var tabs = document.querySelectorAll('.history-tab');
- for (var i = 0; i < tabs.length; i++) {
- tabs[i].classList.toggle('active', tabs[i].getAttribute('data-tab') === tab);
- }
- document.getElementById('history-panel-workers').style.display = tab === 'workers' ? '' : 'none';
- document.getElementById('history-panel-clients').style.display = tab === 'clients' ? '' : 'none';
- }
-
- function renderProvisioningHistory(events) {
- var emptyState = document.getElementById('history-empty');
- var table = document.getElementById('history-table');
- var tbody = document.getElementById('history-table-body');
-
- if (!events || events.length === 0) {
- emptyState.style.display = '';
- table.style.display = 'none';
- return;
- }
-
- emptyState.style.display = 'none';
- table.style.display = '';
- tbody.innerHTML = '';
-
- for (var i = 0; i < events.length; i++) {
- var evt = events[i];
- var tr = document.createElement('tr');
- tr.innerHTML =
- '<td style="color: var(--theme_g1);">' + formatTimestamp(evt.ts) + '</td>' +
- '<td>' + eventBadge(evt.type) + '</td>' +
- '<td>' + escapeHtml(evt.worker_id || '') + '</td>' +
- '<td>' + escapeHtml(evt.hostname || '') + '</td>';
- tbody.appendChild(tr);
- }
- }
-
- function clientHealthClass(dtMs) {
- if (dtMs == null) return 'health-red';
- var seconds = dtMs / 1000;
- if (seconds < 30) return 'health-green';
- if (seconds < 120) return 'health-yellow';
- return 'health-red';
- }
-
- function renderClients(clients) {
- var emptyState = document.getElementById('clients-empty');
- var table = document.getElementById('clients-table');
- var tbody = document.getElementById('clients-table-body');
-
- if (!clients || clients.length === 0) {
- emptyState.style.display = '';
- table.style.display = 'none';
- return;
- }
-
- emptyState.style.display = 'none';
- table.style.display = '';
- tbody.innerHTML = '';
-
- for (var i = 0; i < clients.length; i++) {
- var c = clients[i];
- var dt = c.dt;
- var hClass = clientHealthClass(dt);
- var hTitle = dt != null ? 'Last seen ' + formatLastSeen(dt) : 'Never seen';
-
- var sessionBadge = '';
- if (c.session_id) {
- sessionBadge = ' <span style="font-family:monospace;font-size:10px;color:var(--theme_faint);" title="Session ' + escapeHtml(c.session_id) + '">' + escapeHtml(c.session_id.substring(0, 8)) + '</span>';
- }
-
- var tr = document.createElement('tr');
- tr.innerHTML =
- '<td style="text-align: center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' +
- '<td>' + escapeHtml(c.id || '') + sessionBadge + '</td>' +
- '<td>' + escapeHtml(c.hostname || '') + '</td>' +
- '<td style="font-family: monospace; font-size: 12px; color: var(--theme_g1);">' + escapeHtml(c.address || '') + '</td>' +
- '<td style="text-align: right; color: var(--theme_g1);">' + formatLastSeen(dt) + '</td>';
- tbody.appendChild(tr);
- }
- }
-
- function clientEventBadge(type) {
- var colors = { connected: 'var(--theme_ok)', disconnected: 'var(--theme_fail)', updated: 'var(--theme_warn)' };
- var labels = { connected: 'Connected', disconnected: 'Disconnected', updated: 'Updated' };
- var color = colors[type] || 'var(--theme_g1)';
- var label = labels[type] || type;
- return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:var(--theme_g4);background:' + color + ';">' + escapeHtml(label) + '</span>';
- }
-
- function renderClientHistory(events) {
- var emptyState = document.getElementById('client-history-empty');
- var table = document.getElementById('client-history-table');
- var tbody = document.getElementById('client-history-table-body');
-
- if (!events || events.length === 0) {
- emptyState.style.display = '';
- table.style.display = 'none';
- return;
- }
-
- emptyState.style.display = 'none';
- table.style.display = '';
- tbody.innerHTML = '';
-
- for (var i = 0; i < events.length; i++) {
- var evt = events[i];
- var tr = document.createElement('tr');
- tr.innerHTML =
- '<td style="color: var(--theme_g1);">' + formatTimestamp(evt.ts) + '</td>' +
- '<td>' + clientEventBadge(evt.type) + '</td>' +
- '<td>' + escapeHtml(evt.client_id || '') + '</td>' +
- '<td>' + escapeHtml(evt.hostname || '') + '</td>';
- tbody.appendChild(tr);
- }
- }
-
- // Fetch-based polling fallback
- var pollTimer = null;
-
- async function fetchProvisioningHistory() {
- try {
- var response = await fetch(BASE_URL + '/orch/history?limit=50', {
- headers: { 'Accept': 'application/json' }
- });
- if (response.ok) {
- var data = await response.json();
- renderProvisioningHistory(data.events || []);
- }
- } catch (e) {
- console.error('Error fetching provisioning history:', e);
- }
- }
-
- async function fetchClients() {
- try {
- var response = await fetch(BASE_URL + '/orch/clients', {
- headers: { 'Accept': 'application/json' }
- });
- if (response.ok) {
- var data = await response.json();
- renderClients(data.clients || []);
- }
- } catch (e) {
- console.error('Error fetching clients:', e);
- }
- }
-
- async function fetchClientHistory() {
- try {
- var response = await fetch(BASE_URL + '/orch/clients/history?limit=50', {
- headers: { 'Accept': 'application/json' }
- });
- if (response.ok) {
- var data = await response.json();
- renderClientHistory(data.client_events || []);
- }
- } catch (e) {
- console.error('Error fetching client history:', e);
- }
- }
-
- async function fetchDashboard() {
- var banner = document.querySelector('zen-banner');
- try {
- var response = await fetch(BASE_URL + '/orch/agents', {
- headers: { 'Accept': 'application/json' }
- });
-
- if (!response.ok) {
- banner.setAttribute('cluster-status', 'degraded');
- throw new Error('HTTP ' + response.status + ': ' + response.statusText);
- }
-
- renderDashboard(await response.json());
- fetchProvisioningHistory();
- fetchClients();
- fetchClientHistory();
- } catch (error) {
- console.error('Error updating dashboard:', error);
- showError(error.message);
- banner.setAttribute('cluster-status', 'offline');
- }
- }
-
- function startPolling() {
- if (pollTimer) return;
- fetchDashboard();
- pollTimer = setInterval(fetchDashboard, REFRESH_INTERVAL);
- }
-
- function stopPolling() {
- if (pollTimer) {
- clearInterval(pollTimer);
- pollTimer = null;
- }
- }
-
- // WebSocket connection with automatic reconnect and polling fallback
- var ws = null;
-
- function connectWebSocket() {
- var proto = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
- ws = new WebSocket(proto + '//' + window.location.host + '/orch/ws');
-
- ws.onopen = function() {
- stopPolling();
- clearError();
- };
-
- ws.onmessage = function(event) {
- try {
- renderDashboard(JSON.parse(event.data));
- } catch (e) {
- console.error('WebSocket message parse error:', e);
- }
- };
-
- ws.onclose = function() {
- ws = null;
- startPolling();
- setTimeout(connectWebSocket, 3000);
- };
-
- ws.onerror = function() {
- // onclose will fire after onerror
- };
- }
-
- // Fetch orchestrator hostname for the banner
- fetch(BASE_URL + '/orch/status', { headers: { 'Accept': 'application/json' } })
- .then(function(r) { return r.ok ? r.json() : null; })
- .then(function(d) {
- if (d && d.hostname) {
- document.querySelector('zen-banner').setAttribute('tagline', 'Orchestrator \u2014 ' + d.hostname);
- }
- })
- .catch(function() {});
-
- // Initial load via fetch, then try WebSocket
- fetchDashboard();
- connectWebSocket();
- </script>
-</body>
-</html>
diff --git a/src/zenserver/frontend/html/pages/builds.js b/src/zenserver/frontend/html/pages/builds.js
index 095f0bf29..c63d13b91 100644
--- a/src/zenserver/frontend/html/pages/builds.js
+++ b/src/zenserver/frontend/html/pages/builds.js
@@ -16,7 +16,7 @@ export class Page extends ZenPage
this.set_title("build store");
// Build Store Stats
- const stats_section = this.add_section("Build Store Stats");
+ const stats_section = this._collapsible_section("Build Store Service Stats");
stats_section.tag().classify("dropall").text("raw yaml \u2192").on_click(() => {
window.open("/stats/builds.yaml", "_blank");
});
@@ -39,6 +39,7 @@ export class Page extends ZenPage
_render_stats(stats)
{
+ stats = this._merge_last_stats(stats);
const grid = this._stats_grid;
const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj);
@@ -49,39 +50,30 @@ export class Page extends ZenPage
// Build Store tile
{
- const blobs = safe(stats, "store.blobs");
- const metadata = safe(stats, "store.metadata");
- if (blobs || metadata)
- {
- const tile = grid.tag().classify("card").classify("stats-tile");
- tile.tag().classify("card-title").text("Build Store");
- const columns = tile.tag().classify("tile-columns");
+ const blobs = safe(stats, "store.blobs") || {};
+ const metadata = safe(stats, "store.metadata") || {};
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Build Store");
+ const columns = tile.tag().classify("tile-columns");
- const left = columns.tag().classify("tile-metrics");
- this._metric(left, Friendly.bytes(safe(stats, "store.size.disk") || 0), "disk", true);
- if (blobs)
- {
- this._metric(left, Friendly.sep(blobs.count || 0), "blobs");
- this._metric(left, Friendly.sep(blobs.readcount || 0), "blob reads");
- this._metric(left, Friendly.sep(blobs.writecount || 0), "blob writes");
- const blobHitRatio = (blobs.readcount || 0) > 0
- ? (((blobs.hitcount || 0) / blobs.readcount) * 100).toFixed(1) + "%"
- : "-";
- this._metric(left, blobHitRatio, "blob hit ratio");
- }
+ const left = columns.tag().classify("tile-metrics");
+ this._metric(left, Friendly.bytes(safe(stats, "store.size.disk") || 0), "disk", true);
+ this._metric(left, Friendly.sep(blobs.count || 0), "blobs");
+ this._metric(left, Friendly.sep(blobs.readcount || 0), "blob reads");
+ this._metric(left, Friendly.sep(blobs.writecount || 0), "blob writes");
+ const blobHitRatio = (blobs.readcount || 0) > 0
+ ? (((blobs.hitcount || 0) / blobs.readcount) * 100).toFixed(1) + "%"
+ : "-";
+ this._metric(left, blobHitRatio, "blob hit ratio");
- const right = columns.tag().classify("tile-metrics");
- if (metadata)
- {
- this._metric(right, Friendly.sep(metadata.count || 0), "metadata entries", true);
- this._metric(right, Friendly.sep(metadata.readcount || 0), "meta reads");
- this._metric(right, Friendly.sep(metadata.writecount || 0), "meta writes");
- const metaHitRatio = (metadata.readcount || 0) > 0
- ? (((metadata.hitcount || 0) / metadata.readcount) * 100).toFixed(1) + "%"
- : "-";
- this._metric(right, metaHitRatio, "meta hit ratio");
- }
- }
+ const right = columns.tag().classify("tile-metrics");
+ this._metric(right, Friendly.sep(metadata.count || 0), "metadata entries", true);
+ this._metric(right, Friendly.sep(metadata.readcount || 0), "meta reads");
+ this._metric(right, Friendly.sep(metadata.writecount || 0), "meta writes");
+ const metaHitRatio = (metadata.readcount || 0) > 0
+ ? (((metadata.hitcount || 0) / metadata.readcount) * 100).toFixed(1) + "%"
+ : "-";
+ this._metric(right, metaHitRatio, "meta hit ratio");
}
}
diff --git a/src/zenserver/frontend/html/pages/cache.js b/src/zenserver/frontend/html/pages/cache.js
index 1fc8227c8..683f7df4f 100644
--- a/src/zenserver/frontend/html/pages/cache.js
+++ b/src/zenserver/frontend/html/pages/cache.js
@@ -6,7 +6,7 @@ import { ZenPage } from "./page.js"
import { Fetcher } from "../util/fetcher.js"
import { Friendly } from "../util/friendly.js"
import { Modal } from "../util/modal.js"
-import { Table, Toolbar } from "../util/widgets.js"
+import { Table, Toolbar, Pager, add_copy_button } from "../util/widgets.js"
////////////////////////////////////////////////////////////////////////////////
export class Page extends ZenPage
@@ -44,8 +44,6 @@ export class Page extends ZenPage
// Cache Namespaces
var section = this._collapsible_section("Cache Namespaces");
- section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all());
-
var columns = [
"namespace",
"dir",
@@ -56,31 +54,30 @@ export class Page extends ZenPage
"actions",
];
- var zcache_info = await new Fetcher().resource("/z$/").json();
this._cache_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_AlignNumeric);
- for (const namespace of zcache_info["Namespaces"] || [])
- {
- new Fetcher().resource(`/z$/${namespace}/`).json().then((data) => {
- const row = this._cache_table.add_row(
- "",
- data["Configuration"]["RootDir"],
- data["Buckets"].length,
- data["EntryCount"],
- Friendly.bytes(data["StorageSize"].DiskSize),
- Friendly.bytes(data["StorageSize"].MemorySize)
- );
- var cell = row.get_cell(0);
- cell.tag().text(namespace).on_click(() => this.view_namespace(namespace));
-
- cell = row.get_cell(-1);
- const action_tb = new Toolbar(cell, true);
- action_tb.left().add("view").on_click(() => this.view_namespace(namespace));
- action_tb.left().add("drop").on_click(() => this.drop_namespace(namespace));
-
- row.attr("zs_name", namespace);
- });
- }
+ this._cache_pager = new Pager(section, 25, () => this._render_cache_page(),
+ Pager.make_search_fn(() => this._cache_data, item => item.namespace));
+ const cache_drop_link = document.createElement("span");
+ cache_drop_link.className = "dropall zen_action";
+ cache_drop_link.style.position = "static";
+ cache_drop_link.textContent = "drop-all";
+ cache_drop_link.addEventListener("click", () => this.drop_all());
+ this._cache_pager.prepend(cache_drop_link);
+
+ const loading = Pager.loading(section);
+ const zcache_info = await new Fetcher().resource("/z$/").json();
+ const namespaces = zcache_info["Namespaces"] || [];
+ const results = await Promise.allSettled(
+ namespaces.map(ns => new Fetcher().resource(`/z$/${ns}/`).json().then(data => ({ namespace: ns, data })))
+ );
+ this._cache_data = results
+ .filter(r => r.status === "fulfilled")
+ .map(r => r.value)
+ .sort((a, b) => a.namespace.localeCompare(b.namespace));
+ this._cache_pager.set_total(this._cache_data.length);
+ this._render_cache_page();
+ loading.remove();
// Namespace detail area (inside namespaces section so it collapses together)
this._namespace_host = section;
@@ -95,84 +92,79 @@ export class Page extends ZenPage
}
}
- _collapsible_section(name)
+ _render_cache_page()
{
- const section = this.add_section(name);
- const container = section._parent.inner();
- const heading = container.firstElementChild;
+ const { start, end } = this._cache_pager.page_range();
+ this._cache_table.clear(start);
+ for (let i = start; i < end; i++)
+ {
+ const item = this._cache_data[i];
+ const data = item.data;
+ const row = this._cache_table.add_row(
+ "",
+ data["Configuration"]["RootDir"],
+ data["Buckets"].length,
+ data["EntryCount"],
+ Friendly.bytes(data["StorageSize"].DiskSize),
+ Friendly.bytes(data["StorageSize"].MemorySize)
+ );
- heading.style.cursor = "pointer";
- heading.style.userSelect = "none";
+ const cell = row.get_cell(0);
+ cell.tag().text(item.namespace).on_click(() => this.view_namespace(item.namespace));
+ add_copy_button(cell.inner(), item.namespace);
+ add_copy_button(row.get_cell(1).inner(), data["Configuration"]["RootDir"]);
- const indicator = document.createElement("span");
- indicator.textContent = " \u25BC";
- indicator.style.fontSize = "0.7em";
- heading.appendChild(indicator);
+ const action_cell = row.get_cell(-1);
+ const action_tb = new Toolbar(action_cell, true);
+ action_tb.left().add("view").on_click(() => this.view_namespace(item.namespace));
+ action_tb.left().add("drop").on_click(() => this.drop_namespace(item.namespace));
- let collapsed = false;
- heading.addEventListener("click", (e) => {
- if (e.target !== heading && e.target !== indicator)
- {
- return;
- }
- collapsed = !collapsed;
- indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
- let sibling = heading.nextElementSibling;
- while (sibling)
- {
- sibling.style.display = collapsed ? "none" : "";
- sibling = sibling.nextElementSibling;
- }
- });
-
- return section;
+ row.attr("zs_name", item.namespace);
+ }
}
_render_stats(stats)
{
+ stats = this._merge_last_stats(stats);
const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj);
const grid = this._stats_grid;
- this._last_stats = stats;
grid.inner().innerHTML = "";
// Store I/O tile
{
- const store = safe(stats, "cache.store");
- if (store)
- {
- const tile = grid.tag().classify("card").classify("stats-tile").classify("stats-tile-detailed");
- if (this._selected_category === "store") tile.classify("stats-tile-selected");
- tile.on_click(() => this._select_category("store"));
- tile.tag().classify("card-title").text("Store I/O");
- const columns = tile.tag().classify("tile-columns");
-
- const left = columns.tag().classify("tile-metrics");
- const storeHits = store.hits || 0;
- const storeMisses = store.misses || 0;
- const storeTotal = storeHits + storeMisses;
- const storeRatio = storeTotal > 0 ? ((storeHits / storeTotal) * 100).toFixed(1) + "%" : "-";
- this._metric(left, storeRatio, "store hit ratio", true);
- this._metric(left, Friendly.sep(storeHits), "hits");
- this._metric(left, Friendly.sep(storeMisses), "misses");
- this._metric(left, Friendly.sep(store.writes || 0), "writes");
- this._metric(left, Friendly.sep(store.rejected_reads || 0), "rejected reads");
- this._metric(left, Friendly.sep(store.rejected_writes || 0), "rejected writes");
-
- const right = columns.tag().classify("tile-metrics");
- const readRateMean = safe(store, "read.bytes.rate_mean") || 0;
- const readRate1 = safe(store, "read.bytes.rate_1") || 0;
- const readRate5 = safe(store, "read.bytes.rate_5") || 0;
- const writeRateMean = safe(store, "write.bytes.rate_mean") || 0;
- const writeRate1 = safe(store, "write.bytes.rate_1") || 0;
- const writeRate5 = safe(store, "write.bytes.rate_5") || 0;
- this._metric(right, Friendly.bytes(readRateMean) + "/s", "read rate (mean)", true);
- this._metric(right, Friendly.bytes(readRate1) + "/s", "read rate (1m)");
- this._metric(right, Friendly.bytes(readRate5) + "/s", "read rate (5m)");
- this._metric(right, Friendly.bytes(writeRateMean) + "/s", "write rate (mean)");
- this._metric(right, Friendly.bytes(writeRate1) + "/s", "write rate (1m)");
- this._metric(right, Friendly.bytes(writeRate5) + "/s", "write rate (5m)");
- }
+ const store = safe(stats, "cache.store") || {};
+ const tile = grid.tag().classify("card").classify("stats-tile").classify("stats-tile-detailed");
+ if (this._selected_category === "store") tile.classify("stats-tile-selected");
+ tile.on_click(() => this._select_category("store"));
+ tile.tag().classify("card-title").text("Store I/O");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ const storeHits = store.hits || 0;
+ const storeMisses = store.misses || 0;
+ const storeTotal = storeHits + storeMisses;
+ const storeRatio = storeTotal > 0 ? ((storeHits / storeTotal) * 100).toFixed(1) + "%" : "-";
+ this._metric(left, storeRatio, "store hit ratio", true);
+ this._metric(left, Friendly.sep(storeHits), "hits");
+ this._metric(left, Friendly.sep(storeMisses), "misses");
+ this._metric(left, Friendly.sep(store.writes || 0), "writes");
+ this._metric(left, Friendly.sep(store.rejected_reads || 0), "rejected reads");
+ this._metric(left, Friendly.sep(store.rejected_writes || 0), "rejected writes");
+
+ const right = columns.tag().classify("tile-metrics");
+ const readRateMean = safe(store, "read.bytes.rate_mean") || 0;
+ const readRate1 = safe(store, "read.bytes.rate_1") || 0;
+ const readRate5 = safe(store, "read.bytes.rate_5") || 0;
+ const writeRateMean = safe(store, "write.bytes.rate_mean") || 0;
+ const writeRate1 = safe(store, "write.bytes.rate_1") || 0;
+ const writeRate5 = safe(store, "write.bytes.rate_5") || 0;
+ this._metric(right, Friendly.bytes(readRateMean) + "/s", "read rate (mean)", true);
+ this._metric(right, Friendly.bytes(readRate1) + "/s", "read rate (1m)");
+ this._metric(right, Friendly.bytes(readRate5) + "/s", "read rate (5m)");
+ this._metric(right, Friendly.bytes(writeRateMean) + "/s", "write rate (mean)");
+ this._metric(right, Friendly.bytes(writeRate1) + "/s", "write rate (1m)");
+ this._metric(right, Friendly.bytes(writeRate5) + "/s", "write rate (5m)");
}
// Hit/Miss tile
@@ -208,89 +200,83 @@ export class Page extends ZenPage
// HTTP Requests tile
{
- const req = safe(stats, "requests");
- if (req)
- {
- const tile = grid.tag().classify("card").classify("stats-tile");
- tile.tag().classify("card-title").text("HTTP Requests");
- const columns = tile.tag().classify("tile-columns");
+ const req = safe(stats, "requests") || {};
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("HTTP Requests");
+ const columns = tile.tag().classify("tile-columns");
- const left = columns.tag().classify("tile-metrics");
- const reqData = req.requests || req;
- this._metric(left, Friendly.sep(reqData.count || 0), "total requests", true);
- if (reqData.rate_mean > 0)
- {
- this._metric(left, Friendly.sep(reqData.rate_mean, 1) + "/s", "req/sec (mean)");
- }
- if (reqData.rate_1 > 0)
- {
- this._metric(left, Friendly.sep(reqData.rate_1, 1) + "/s", "req/sec (1m)");
- }
- if (reqData.rate_5 > 0)
- {
- this._metric(left, Friendly.sep(reqData.rate_5, 1) + "/s", "req/sec (5m)");
- }
- if (reqData.rate_15 > 0)
- {
- this._metric(left, Friendly.sep(reqData.rate_15, 1) + "/s", "req/sec (15m)");
- }
- const badRequests = safe(stats, "cache.badrequestcount") || 0;
- this._metric(left, Friendly.sep(badRequests), "bad requests");
+ const left = columns.tag().classify("tile-metrics");
+ const reqData = req.requests || req;
+ this._metric(left, Friendly.sep(reqData.count || 0), "total requests", true);
+ if (reqData.rate_mean > 0)
+ {
+ this._metric(left, Friendly.sep(reqData.rate_mean, 1) + "/s", "req/sec (mean)");
+ }
+ if (reqData.rate_1 > 0)
+ {
+ this._metric(left, Friendly.sep(reqData.rate_1, 1) + "/s", "req/sec (1m)");
+ }
+ if (reqData.rate_5 > 0)
+ {
+ this._metric(left, Friendly.sep(reqData.rate_5, 1) + "/s", "req/sec (5m)");
+ }
+ if (reqData.rate_15 > 0)
+ {
+ this._metric(left, Friendly.sep(reqData.rate_15, 1) + "/s", "req/sec (15m)");
+ }
+ const badRequests = safe(stats, "cache.badrequestcount") || 0;
+ this._metric(left, Friendly.sep(badRequests), "bad requests");
- const right = columns.tag().classify("tile-metrics");
- this._metric(right, Friendly.duration(reqData.t_avg || 0), "avg latency", true);
- if (reqData.t_p75)
- {
- this._metric(right, Friendly.duration(reqData.t_p75), "p75");
- }
- if (reqData.t_p95)
- {
- this._metric(right, Friendly.duration(reqData.t_p95), "p95");
- }
- if (reqData.t_p99)
- {
- this._metric(right, Friendly.duration(reqData.t_p99), "p99");
- }
- if (reqData.t_p999)
- {
- this._metric(right, Friendly.duration(reqData.t_p999), "p999");
- }
- if (reqData.t_max)
- {
- this._metric(right, Friendly.duration(reqData.t_max), "max");
- }
+ const right = columns.tag().classify("tile-metrics");
+ this._metric(right, Friendly.duration(reqData.t_avg || 0), "avg latency", true);
+ if (reqData.t_p75)
+ {
+ this._metric(right, Friendly.duration(reqData.t_p75), "p75");
+ }
+ if (reqData.t_p95)
+ {
+ this._metric(right, Friendly.duration(reqData.t_p95), "p95");
+ }
+ if (reqData.t_p99)
+ {
+ this._metric(right, Friendly.duration(reqData.t_p99), "p99");
+ }
+ if (reqData.t_p999)
+ {
+ this._metric(right, Friendly.duration(reqData.t_p999), "p999");
+ }
+ if (reqData.t_max)
+ {
+ this._metric(right, Friendly.duration(reqData.t_max), "max");
}
}
// RPC tile
{
- const rpc = safe(stats, "cache.rpc");
- if (rpc)
- {
- const tile = grid.tag().classify("card").classify("stats-tile");
- tile.tag().classify("card-title").text("RPC");
- const columns = tile.tag().classify("tile-columns");
+ const rpc = safe(stats, "cache.rpc") || {};
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("RPC");
+ const columns = tile.tag().classify("tile-columns");
- const left = columns.tag().classify("tile-metrics");
- this._metric(left, Friendly.sep(rpc.count || 0), "rpc calls", true);
- this._metric(left, Friendly.sep(rpc.ops || 0), "batch ops");
+ const left = columns.tag().classify("tile-metrics");
+ this._metric(left, Friendly.sep(rpc.count || 0), "rpc calls", true);
+ this._metric(left, Friendly.sep(rpc.ops || 0), "batch ops");
- const right = columns.tag().classify("tile-metrics");
- if (rpc.records)
- {
- this._metric(right, Friendly.sep(rpc.records.count || 0), "record calls");
- this._metric(right, Friendly.sep(rpc.records.ops || 0), "record ops");
- }
- if (rpc.values)
- {
- this._metric(right, Friendly.sep(rpc.values.count || 0), "value calls");
- this._metric(right, Friendly.sep(rpc.values.ops || 0), "value ops");
- }
- if (rpc.chunks)
- {
- this._metric(right, Friendly.sep(rpc.chunks.count || 0), "chunk calls");
- this._metric(right, Friendly.sep(rpc.chunks.ops || 0), "chunk ops");
- }
+ const right = columns.tag().classify("tile-metrics");
+ if (rpc.records)
+ {
+ this._metric(right, Friendly.sep(rpc.records.count || 0), "record calls");
+ this._metric(right, Friendly.sep(rpc.records.ops || 0), "record ops");
+ }
+ if (rpc.values)
+ {
+ this._metric(right, Friendly.sep(rpc.values.count || 0), "value calls");
+ this._metric(right, Friendly.sep(rpc.values.ops || 0), "value ops");
+ }
+ if (rpc.chunks)
+ {
+ this._metric(right, Friendly.sep(rpc.chunks.count || 0), "chunk calls");
+ this._metric(right, Friendly.sep(rpc.chunks.ops || 0), "chunk ops");
}
}
@@ -313,7 +299,7 @@ export class Page extends ZenPage
this._metric(right, safe(stats, "cid.size.large") != null ? Friendly.bytes(safe(stats, "cid.size.large")) : "-", "cid large");
}
- // Upstream tile (only if upstream is active)
+ // Upstream tile (only shown when upstream is active)
{
const upstream = safe(stats, "upstream");
if (upstream)
@@ -644,10 +630,9 @@ export class Page extends ZenPage
async drop_all()
{
const drop = async () => {
- for (const row of this._cache_table)
+ for (const item of this._cache_data || [])
{
- const namespace = row.attr("zs_name");
- await new Fetcher().resource("z$", namespace).delete();
+ await new Fetcher().resource("z$", item.namespace).delete();
}
this.reload();
};
diff --git a/src/zenserver/frontend/html/pages/compute.js b/src/zenserver/frontend/html/pages/compute.js
index ab3d49c27..c2257029e 100644
--- a/src/zenserver/frontend/html/pages/compute.js
+++ b/src/zenserver/frontend/html/pages/compute.js
@@ -5,7 +5,7 @@
import { ZenPage } from "./page.js"
import { Fetcher } from "../util/fetcher.js"
import { Friendly } from "../util/friendly.js"
-import { Table } from "../util/widgets.js"
+import { Table, add_copy_button } from "../util/widgets.js"
const MAX_HISTORY_POINTS = 60;
@@ -24,6 +24,12 @@ function formatTime(date)
return date.toLocaleTimeString([], { hour: "2-digit", minute: "2-digit", second: "2-digit" });
}
+function truncateHash(hash)
+{
+ if (!hash || hash.length <= 15) return hash;
+ return hash.slice(0, 6) + "\u2026" + hash.slice(-6);
+}
+
function formatDuration(startDate, endDate)
{
if (!startDate || !endDate) return "-";
@@ -100,39 +106,6 @@ export class Page extends ZenPage
}, 2000);
}
- _collapsible_section(name)
- {
- const section = this.add_section(name);
- const container = section._parent.inner();
- const heading = container.firstElementChild;
-
- heading.style.cursor = "pointer";
- heading.style.userSelect = "none";
-
- const indicator = document.createElement("span");
- indicator.textContent = " \u25BC";
- indicator.style.fontSize = "0.7em";
- heading.appendChild(indicator);
-
- let collapsed = false;
- heading.addEventListener("click", (e) => {
- if (e.target !== heading && e.target !== indicator)
- {
- return;
- }
- collapsed = !collapsed;
- indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
- let sibling = heading.nextElementSibling;
- while (sibling)
- {
- sibling.style.display = collapsed ? "none" : "";
- sibling = sibling.nextElementSibling;
- }
- });
-
- return section;
- }
-
async _load_chartjs()
{
if (window.Chart)
@@ -338,11 +311,7 @@ export class Page extends ZenPage
{
const workerIds = data.workers || [];
- if (this._workers_table)
- {
- this._workers_table.clear();
- }
- else
+ if (!this._workers_table)
{
this._workers_table = this._workers_host.add_widget(
Table,
@@ -353,6 +322,7 @@ export class Page extends ZenPage
if (workerIds.length === 0)
{
+ this._workers_table.clear();
return;
}
@@ -382,6 +352,10 @@ export class Page extends ZenPage
id,
);
+ // Worker ID column: monospace for hex readability, copy button
+ row.get_cell(5).style("fontFamily", "'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace");
+ add_copy_button(row.get_cell(5).inner(), id);
+
// Make name clickable to expand detail
const cell = row.get_cell(0);
cell.tag().text(name).on_click(() => this._toggle_worker_detail(id, desc));
@@ -551,7 +525,7 @@ export class Page extends ZenPage
: q.state === "draining" ? "draining"
: q.is_complete ? "complete" : "active";
- this._queues_table.add_row(
+ const qrow = this._queues_table.add_row(
id,
status,
String(q.active_count ?? 0),
@@ -561,6 +535,10 @@ export class Page extends ZenPage
String(q.cancelled_count ?? 0),
q.queue_token || "-",
);
+ if (q.queue_token)
+ {
+ add_copy_button(qrow.get_cell(7).inner(), q.queue_token);
+ }
}
}
@@ -579,6 +557,11 @@ export class Page extends ZenPage
["LSN", "queue", "status", "function", "started", "finished", "duration", "worker ID", "action ID"],
Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric, -1
);
+
+ // Right-align hash column headers to match data cells
+ const hdr = this._history_table.inner().firstElementChild;
+ hdr.children[7].style.textAlign = "right";
+ hdr.children[8].style.textAlign = "right";
}
// Entries arrive oldest-first; reverse to show newest at top
@@ -593,7 +576,10 @@ export class Page extends ZenPage
const startDate = filetimeToDate(entry.time_Running);
const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed);
- this._history_table.add_row(
+ const workerId = entry.workerId || "-";
+ const actionId = entry.actionId || "-";
+
+ const row = this._history_table.add_row(
lsn,
queueId,
status,
@@ -601,9 +587,17 @@ export class Page extends ZenPage
formatTime(startDate),
formatTime(endDate),
formatDuration(startDate, endDate),
- entry.workerId || "-",
- entry.actionId || "-",
+ truncateHash(workerId),
+ truncateHash(actionId),
);
+
+ // Hash columns: force right-align (AlignNumeric misses hex strings starting with a-f),
+ // use monospace for readability, and show full value on hover
+ const mono = "'SF Mono', 'Cascadia Mono', Consolas, 'DejaVu Sans Mono', monospace";
+ row.get_cell(7).style("textAlign", "right").style("fontFamily", mono).attr("title", workerId);
+ if (workerId !== "-") { add_copy_button(row.get_cell(7).inner(), workerId); }
+ row.get_cell(8).style("textAlign", "right").style("fontFamily", mono).attr("title", actionId);
+ if (actionId !== "-") { add_copy_button(row.get_cell(8).inner(), actionId); }
}
}
diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js
index 1e4c82e3f..e381f4a71 100644
--- a/src/zenserver/frontend/html/pages/entry.js
+++ b/src/zenserver/frontend/html/pages/entry.js
@@ -168,7 +168,7 @@ export class Page extends ZenPage
if (key === "cook.artifacts")
{
action_tb.left().add("view-raw").on_click(() => {
- window.location = "/" + ["prj", project, "oplog", oplog, value+".json"].join("/");
+ window.open("/" + ["prj", project, "oplog", oplog, value+".json"].join("/"), "_self");
});
}
diff --git a/src/zenserver/frontend/html/pages/hub.js b/src/zenserver/frontend/html/pages/hub.js
index 78e3a090c..b2bca9324 100644
--- a/src/zenserver/frontend/html/pages/hub.js
+++ b/src/zenserver/frontend/html/pages/hub.js
@@ -6,6 +6,7 @@ import { ZenPage } from "./page.js"
import { Fetcher } from "../util/fetcher.js"
import { Friendly } from "../util/friendly.js"
import { Modal } from "../util/modal.js"
+import { flash_highlight, copy_button } from "../util/widgets.js"
////////////////////////////////////////////////////////////////////////////////
const STABLE_STATES = new Set(["provisioned", "hibernated", "crashed"]);
@@ -20,6 +21,7 @@ function _btn_enabled(state, action)
if (action === "hibernate") { return state === "provisioned"; }
if (action === "wake") { return state === "hibernated"; }
if (action === "deprovision") { return _is_actionable(state); }
+ if (action === "obliterate") { return _is_actionable(state); }
return false;
}
@@ -82,7 +84,7 @@ export class Page extends ZenPage
this.set_title("hub");
// Capacity
- const stats_section = this.add_section("Capacity");
+ const stats_section = this._collapsible_section("Hub Service Stats");
this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles");
// Modules
@@ -96,20 +98,24 @@ export class Page extends ZenPage
this._bulk_label.className = "module-bulk-label";
this._btn_bulk_hibernate = _make_bulk_btn("\u23F8", "Hibernate", () => this._exec_action("hibernate", [...this._selected]));
this._btn_bulk_wake = _make_bulk_btn("\u25B6", "Wake", () => this._exec_action("wake", [...this._selected]));
- this._btn_bulk_deprov = _make_bulk_btn("\u2715", "Deprovision",() => this._confirm_deprovision([...this._selected]));
+ this._btn_bulk_deprov = _make_bulk_btn("\u23F9", "Deprovision",() => this._confirm_deprovision([...this._selected]));
+ this._btn_bulk_oblit = _make_bulk_btn("\uD83D\uDD25", "Obliterate", () => this._confirm_obliterate([...this._selected]));
const bulk_sep = document.createElement("div");
bulk_sep.className = "module-bulk-sep";
this._btn_hibernate_all = _make_bulk_btn("\u23F8", "Hibernate All", () => this._confirm_all("hibernate", "Hibernate All"));
this._btn_wake_all = _make_bulk_btn("\u25B6", "Wake All", () => this._confirm_all("wake", "Wake All"));
- this._btn_deprov_all = _make_bulk_btn("\u2715", "Deprovision All",() => this._confirm_all("deprovision", "Deprovision All"));
+ this._btn_deprov_all = _make_bulk_btn("\u23F9", "Deprovision All",() => this._confirm_all("deprovision", "Deprovision All"));
+ this._btn_oblit_all = _make_bulk_btn("\uD83D\uDD25", "Obliterate All", () => this._confirm_obliterate(this._modules_data.map(m => m.moduleId)));
this._bulk_bar.appendChild(this._bulk_label);
this._bulk_bar.appendChild(this._btn_bulk_hibernate);
this._bulk_bar.appendChild(this._btn_bulk_wake);
this._bulk_bar.appendChild(this._btn_bulk_deprov);
+ this._bulk_bar.appendChild(this._btn_bulk_oblit);
this._bulk_bar.appendChild(bulk_sep);
this._bulk_bar.appendChild(this._btn_hibernate_all);
this._bulk_bar.appendChild(this._btn_wake_all);
this._bulk_bar.appendChild(this._btn_deprov_all);
+ this._bulk_bar.appendChild(this._btn_oblit_all);
mod_host.appendChild(this._bulk_bar);
// Module table
@@ -152,6 +158,38 @@ export class Page extends ZenPage
this._btn_next.className = "module-pager-btn";
this._btn_next.textContent = "Next \u2192";
this._btn_next.addEventListener("click", () => this._go_page(this._page + 1));
+ this._btn_provision = _make_bulk_btn("+", "Provision", () => this._show_provision_modal());
+ this._btn_obliterate = _make_bulk_btn("\uD83D\uDD25", "Obliterate", () => this._show_obliterate_modal());
+ this._search_input = document.createElement("input");
+ this._search_input.type = "text";
+ this._search_input.className = "module-pager-search";
+ this._search_input.placeholder = "Search module\u2026";
+ this._search_input.addEventListener("keydown", (e) =>
+ {
+ if (e.key === "Enter")
+ {
+ const term = this._search_input.value.trim().toLowerCase();
+ if (!term) { return; }
+ const idx = this._modules_data.findIndex(m =>
+ (m.moduleId || "").toLowerCase().includes(term)
+ );
+ if (idx >= 0)
+ {
+ const id = this._modules_data[idx].moduleId;
+ this._navigate_to_module(id);
+ this._flash_module(id);
+ }
+ else
+ {
+ this._search_input.style.outline = "2px solid var(--theme_fail)";
+ setTimeout(() => { this._search_input.style.outline = ""; }, 1000);
+ }
+ }
+ });
+
+ pager.appendChild(this._btn_provision);
+ pager.appendChild(this._btn_obliterate);
+ pager.appendChild(this._search_input);
pager.appendChild(this._btn_prev);
pager.appendChild(this._pager_label);
pager.appendChild(this._btn_next);
@@ -164,8 +202,11 @@ export class Page extends ZenPage
this._row_cache = new Map(); // moduleId → row refs, for in-place DOM updates
this._updating = false;
this._page = 0;
- this._page_size = 50;
+ this._page_size = 25;
this._expanded = new Set(); // moduleIds with open metrics panel
+ this._pending_highlight = null; // moduleId to navigate+flash after next poll
+ this._pending_highlight_timer = null;
+ this._loading = mod_section.tag().classify("pager-loading").text("Loading\u2026").inner();
await this._update();
this._poll_timer = setInterval(() => this._update(), 2000);
@@ -184,6 +225,15 @@ export class Page extends ZenPage
this._render_capacity(stats);
this._render_modules(status);
+ if (this._loading) { this._loading.remove(); this._loading = null; }
+ if (this._pending_highlight && this._module_map.has(this._pending_highlight))
+ {
+ const id = this._pending_highlight;
+ this._pending_highlight = null;
+ clearTimeout(this._pending_highlight_timer);
+ this._navigate_to_module(id);
+ this._flash_module(id);
+ }
}
catch (e) { /* service unavailable */ }
finally { this._updating = false; }
@@ -203,27 +253,48 @@ export class Page extends ZenPage
{
const tile = grid.tag().classify("card").classify("stats-tile");
- tile.tag().classify("card-title").text("Active Modules");
+ tile.tag().classify("card-title").text("Instances");
const body = tile.tag().classify("tile-metrics");
this._metric(body, Friendly.sep(current), "currently provisioned", true);
+ this._metric(body, Friendly.sep(max), "high watermark");
+ this._metric(body, Friendly.sep(limit), "maximum allowed");
+ if (limit > 0)
+ {
+ const pct = ((current / limit) * 100).toFixed(0) + "%";
+ this._metric(body, pct, "utilization");
+ }
}
+ const machine = data.machine || {};
+ const limits = data.resource_limits || {};
+ if (machine.disk_total_bytes > 0 || machine.memory_total_mib > 0)
{
- const tile = grid.tag().classify("card").classify("stats-tile");
- tile.tag().classify("card-title").text("Peak Modules");
- const body = tile.tag().classify("tile-metrics");
- this._metric(body, Friendly.sep(max), "high watermark", true);
- }
+ const disk_used = Math.max(0, (machine.disk_total_bytes || 0) - (machine.disk_free_bytes || 0));
+ const mem_used = Math.max(0, (machine.memory_total_mib || 0) - (machine.memory_avail_mib || 0)) * 1024 * 1024;
+ const vmem_used = Math.max(0, (machine.virtual_memory_total_mib || 0) - (machine.virtual_memory_avail_mib || 0)) * 1024 * 1024;
+ const disk_limit = limits.disk_bytes || 0;
+ const mem_limit = limits.memory_bytes || 0;
+ const disk_over = disk_limit > 0 && disk_used > disk_limit;
+ const mem_over = mem_limit > 0 && mem_used > mem_limit;
- {
const tile = grid.tag().classify("card").classify("stats-tile");
- tile.tag().classify("card-title").text("Instance Limit");
- const body = tile.tag().classify("tile-metrics");
- this._metric(body, Friendly.sep(limit), "maximum allowed", true);
- if (limit > 0)
+ if (disk_over || mem_over) { tile.inner().setAttribute("data-over", "true"); }
+ tile.tag().classify("card-title").text("Resources");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ this._metric(left, Friendly.bytes(disk_used), "disk used", true);
+ this._metric(left, Friendly.bytes(machine.disk_total_bytes), "disk total");
+ if (disk_limit > 0) { this._metric(left, Friendly.bytes(disk_limit), "disk limit"); }
+
+ const right = columns.tag().classify("tile-metrics");
+ this._metric(right, Friendly.bytes(mem_used), "memory used", true);
+ this._metric(right, Friendly.bytes(machine.memory_total_mib * 1024 * 1024), "memory total");
+ if (mem_limit > 0) { this._metric(right, Friendly.bytes(mem_limit), "memory limit"); }
+ if (machine.virtual_memory_total_mib > 0)
{
- const pct = ((current / limit) * 100).toFixed(0) + "%";
- this._metric(body, pct, "utilization");
+ this._metric(right, Friendly.bytes(vmem_used), "vmem used", true);
+ this._metric(right, Friendly.bytes(machine.virtual_memory_total_mib * 1024 * 1024), "vmem total");
}
}
}
@@ -274,7 +345,7 @@ export class Page extends ZenPage
row.idx.textContent = i + 1;
row.cb.checked = this._selected.has(id);
row.dot.setAttribute("data-state", state);
- if (state === "deprovisioning")
+ if (state === "deprovisioning" || state === "obliterating")
{
row.dot.setAttribute("data-prev-state", prev);
}
@@ -284,10 +355,20 @@ export class Page extends ZenPage
}
row.state_text.nodeValue = state;
row.port_text.nodeValue = m.port ? String(m.port) : "";
+ row.copy_port_btn.style.display = m.port ? "" : "none";
+ if (m.state_change_time)
+ {
+ const state_label = state.charAt(0).toUpperCase() + state.slice(1);
+ row.state_since_label.textContent = state_label + " since";
+ row.state_age_label.textContent = state_label + " for";
+ row.state_since_node.nodeValue = m.state_change_time;
+ row.state_age_node.nodeValue = Friendly.timespan(Date.now() - new Date(m.state_change_time).getTime());
+ }
row.btn_open.disabled = state !== "provisioned";
row.btn_hibernate.disabled = !_btn_enabled(state, "hibernate");
row.btn_wake.disabled = !_btn_enabled(state, "wake");
row.btn_deprov.disabled = !_btn_enabled(state, "deprovision");
+ row.btn_oblit.disabled = !_btn_enabled(state, "obliterate");
if (m.process_metrics)
{
@@ -347,6 +428,8 @@ export class Page extends ZenPage
id_wrap.style.cssText = "display:inline-flex;align-items:center;font-family:monospace;font-size:14px;";
id_wrap.appendChild(btn_expand);
id_wrap.appendChild(document.createTextNode("\u00A0" + id));
+ const copy_id_btn = copy_button(id);
+ id_wrap.appendChild(copy_id_btn);
td_id.appendChild(id_wrap);
tr.appendChild(td_id);
@@ -354,7 +437,7 @@ export class Page extends ZenPage
const dot = document.createElement("span");
dot.className = "module-state-dot";
dot.setAttribute("data-state", state);
- if (state === "deprovisioning")
+ if (state === "deprovisioning" || state === "obliterating")
{
dot.setAttribute("data-prev-state", prev);
}
@@ -368,27 +451,33 @@ export class Page extends ZenPage
td_port.style.cssText = "font-variant-numeric:tabular-nums;";
const port_node = document.createTextNode(port ? String(port) : "");
td_port.appendChild(port_node);
+ const copy_port_btn = copy_button(() => port_node.nodeValue);
+ copy_port_btn.style.display = port ? "" : "none";
+ td_port.appendChild(copy_port_btn);
tr.appendChild(td_port);
const td_action = document.createElement("td");
td_action.className = "module-action-cell";
const [wrap_o, btn_o] = _make_action_btn("\u2197", "Open dashboard", () => {
- window.open(`${window.location.protocol}//${window.location.hostname}:${port}`, "_blank");
+ window.open(`/hub/proxy/${port}/dashboard/`, "_blank");
});
btn_o.disabled = state !== "provisioned";
const [wrap_h, btn_h] = _make_action_btn("\u23F8", "Hibernate", () => this._post_module_action(id, "hibernate").then(() => this._update()));
const [wrap_w, btn_w] = _make_action_btn("\u25B6", "Wake", () => this._post_module_action(id, "wake").then(() => this._update()));
- const [wrap_d, btn_d] = _make_action_btn("\u2715", "Deprovision", () => this._confirm_deprovision([id]));
+ const [wrap_d, btn_d] = _make_action_btn("\u23F9", "Deprovision", () => this._confirm_deprovision([id]));
+ const [wrap_x, btn_x] = _make_action_btn("\uD83D\uDD25", "Obliterate", () => this._confirm_obliterate([id]));
btn_h.disabled = !_btn_enabled(state, "hibernate");
btn_w.disabled = !_btn_enabled(state, "wake");
btn_d.disabled = !_btn_enabled(state, "deprovision");
+ btn_x.disabled = !_btn_enabled(state, "obliterate");
td_action.appendChild(wrap_h);
td_action.appendChild(wrap_w);
td_action.appendChild(wrap_d);
+ td_action.appendChild(wrap_x);
td_action.appendChild(wrap_o);
tr.appendChild(td_action);
- // Build metrics grid from process_metrics keys.
+ // Build metrics grid: fixed state-time rows followed by process_metrics keys.
// Keys are split into two halves and interleaved so the grid fills
// top-to-bottom in the left column before continuing in the right column.
const metric_nodes = new Map();
@@ -396,6 +485,28 @@ export class Page extends ZenPage
metrics_td.colSpan = 6;
const metrics_grid = document.createElement("div");
metrics_grid.className = "module-metrics-grid";
+
+ const _add_fixed_pair = (label, value_str) => {
+ const label_el = document.createElement("span");
+ label_el.className = "module-metrics-label";
+ label_el.textContent = label;
+ const value_node = document.createTextNode(value_str);
+ const value_el = document.createElement("span");
+ value_el.className = "module-metrics-value";
+ value_el.appendChild(value_node);
+ metrics_grid.appendChild(label_el);
+ metrics_grid.appendChild(value_el);
+ return { label_el, value_node };
+ };
+
+ const state_label = m.state ? m.state.charAt(0).toUpperCase() + m.state.slice(1) : "State";
+ const state_since_str = m.state_change_time || "";
+ const state_age_str = m.state_change_time
+ ? Friendly.timespan(Date.now() - new Date(m.state_change_time).getTime())
+ : "";
+ const { label_el: state_since_label, value_node: state_since_node } = _add_fixed_pair(state_label + " since", state_since_str);
+ const { label_el: state_age_label, value_node: state_age_node } = _add_fixed_pair(state_label + " for", state_age_str);
+
const keys = Object.keys(m.process_metrics || {});
const half = Math.ceil(keys.length / 2);
const add_metric_pair = (key) => {
@@ -423,7 +534,7 @@ export class Page extends ZenPage
metrics_td.appendChild(metrics_grid);
metrics_tr.appendChild(metrics_td);
- row = { tr, metrics_tr, idx: td_idx, cb, dot, state_text: state_node, port_text: port_node, btn_expand, btn_open: btn_o, btn_hibernate: btn_h, btn_wake: btn_w, btn_deprov: btn_d, metric_nodes };
+ row = { tr, metrics_tr, idx: td_idx, cb, dot, state_text: state_node, port_text: port_node, copy_port_btn, btn_expand, btn_open: btn_o, btn_hibernate: btn_h, btn_wake: btn_w, btn_deprov: btn_d, btn_oblit: btn_x, metric_nodes, state_since_node, state_age_node, state_since_label, state_age_label };
this._row_cache.set(id, row);
}
@@ -533,6 +644,7 @@ export class Page extends ZenPage
this._btn_bulk_hibernate.disabled = !this._all_selected_in_state("provisioned");
this._btn_bulk_wake.disabled = !this._all_selected_in_state("hibernated");
this._btn_bulk_deprov.disabled = selected === 0;
+ this._btn_bulk_oblit.disabled = selected === 0;
this._select_all_cb.disabled = total === 0;
this._select_all_cb.checked = selected === total && total > 0;
@@ -545,6 +657,7 @@ export class Page extends ZenPage
this._btn_hibernate_all.disabled = empty;
this._btn_wake_all.disabled = empty;
this._btn_deprov_all.disabled = empty;
+ this._btn_oblit_all.disabled = empty;
}
_on_select_all()
@@ -590,6 +703,35 @@ export class Page extends ZenPage
.option("Deprovision", () => this._exec_action("deprovision", ids));
}
+ _confirm_obliterate(ids)
+ {
+ const warn = "\uD83D\uDD25 WARNING: This action is irreversible! \uD83D\uDD25";
+ const detail = "All local and backend data will be permanently destroyed.\nThis cannot be undone.";
+ let message;
+ if (ids.length === 1)
+ {
+ const id = ids[0];
+ const state = this._module_state(id) || "unknown";
+ message = `${warn}\n\n${detail}\n\nModule ID: ${id}\nCurrent state: ${state}`;
+ }
+ else
+ {
+ message = `${warn}\n\nObliterate ${ids.length} modules.\n\n${detail}`;
+ }
+
+ new Modal()
+ .title("\uD83D\uDD25 Obliterate")
+ .message(message)
+ .option("Cancel", null)
+ .option("\uD83D\uDD25 Obliterate", () => this._exec_obliterate(ids));
+ }
+
+ async _exec_obliterate(ids)
+ {
+ await Promise.allSettled(ids.map(id => fetch(`/hub/modules/${encodeURIComponent(id)}`, { method: "DELETE" })));
+ await this._update();
+ }
+
_confirm_all(action, label)
{
// Capture IDs at modal-open time so action targets the displayed list
@@ -614,4 +756,191 @@ export class Page extends ZenPage
await fetch(`/hub/modules/${moduleId}/${action}`, { method: "POST" });
}
+ _show_module_input_modal({ title, submit_label, warning, on_submit })
+ {
+ const MODULE_ID_RE = /^[A-Za-z0-9][A-Za-z0-9-]*$/;
+
+ const overlay = document.createElement("div");
+ overlay.className = "zen_modal";
+
+ const bg = document.createElement("div");
+ bg.className = "zen_modal_bg";
+ bg.addEventListener("click", () => overlay.remove());
+ overlay.appendChild(bg);
+
+ const dialog = document.createElement("div");
+ overlay.appendChild(dialog);
+
+ const title_el = document.createElement("div");
+ title_el.className = "zen_modal_title";
+ title_el.textContent = title;
+ dialog.appendChild(title_el);
+
+ const content = document.createElement("div");
+ content.className = "zen_modal_message";
+ content.style.textAlign = "center";
+
+ if (warning)
+ {
+ const warn = document.createElement("div");
+ warn.style.cssText = "color:var(--theme_fail);font-weight:bold;margin-bottom:12px;";
+ warn.textContent = warning;
+ content.appendChild(warn);
+ }
+
+ const input = document.createElement("input");
+ input.type = "text";
+ input.placeholder = "module-name";
+ input.style.cssText = "width:100%;font-size:14px;padding:8px 12px;";
+ content.appendChild(input);
+
+ const error_div = document.createElement("div");
+ error_div.style.cssText = "color:var(--theme_fail);font-size:12px;margin-top:8px;min-height:1.2em;";
+ content.appendChild(error_div);
+
+ dialog.appendChild(content);
+
+ const buttons = document.createElement("div");
+ buttons.className = "zen_modal_buttons";
+
+ const btn_cancel = document.createElement("div");
+ btn_cancel.textContent = "Cancel";
+ btn_cancel.addEventListener("click", () => overlay.remove());
+
+ const btn_submit = document.createElement("div");
+ btn_submit.textContent = submit_label;
+
+ buttons.appendChild(btn_cancel);
+ buttons.appendChild(btn_submit);
+ dialog.appendChild(buttons);
+
+ let submitting = false;
+
+ const set_submit_enabled = (enabled) => {
+ btn_submit.style.opacity = enabled ? "" : "0.4";
+ btn_submit.style.pointerEvents = enabled ? "" : "none";
+ };
+
+ set_submit_enabled(false);
+
+ const validate = () => {
+ if (submitting) { return false; }
+ const val = input.value.trim();
+ if (val.length === 0)
+ {
+ error_div.textContent = "";
+ set_submit_enabled(false);
+ return false;
+ }
+ if (!MODULE_ID_RE.test(val))
+ {
+ error_div.textContent = "Only letters, numbers, and hyphens allowed (must start with a letter or number)";
+ set_submit_enabled(false);
+ return false;
+ }
+ error_div.textContent = "";
+ set_submit_enabled(true);
+ return true;
+ };
+
+ input.addEventListener("input", validate);
+
+ const submit = async () => {
+ if (submitting) { return; }
+ const moduleId = input.value.trim();
+ if (!MODULE_ID_RE.test(moduleId)) { return; }
+
+ submitting = true;
+ set_submit_enabled(false);
+ error_div.textContent = "";
+
+ try
+ {
+ const ok = await on_submit(moduleId);
+ if (ok)
+ {
+ overlay.remove();
+ await this._update();
+ return;
+ }
+ }
+ catch (e)
+ {
+ error_div.textContent = e.message || "Request failed";
+ }
+ submitting = false;
+ set_submit_enabled(true);
+ };
+
+ btn_submit.addEventListener("click", submit);
+ input.addEventListener("keydown", (e) => {
+ if (e.key === "Enter" && validate()) { submit(); }
+ if (e.key === "Escape") { overlay.remove(); }
+ });
+
+ document.body.appendChild(overlay);
+ input.focus();
+
+ return { error_div };
+ }
+
+ _show_provision_modal()
+ {
+ const { error_div } = this._show_module_input_modal({
+ title: "Provision Module",
+ submit_label: "Provision",
+ on_submit: async (moduleId) => {
+ const resp = await fetch(`/hub/modules/${encodeURIComponent(moduleId)}/provision`, { method: "POST" });
+ if (!resp.ok)
+ {
+ const msg = await resp.text();
+ error_div.textContent = msg || ("HTTP " + resp.status);
+ return false;
+ }
+ // Endpoint returns compact binary (CbObjectWriter), not text
+ if (resp.status === 200 || resp.status === 202)
+ {
+ this._pending_highlight = moduleId;
+ this._pending_highlight_timer = setTimeout(() => { this._pending_highlight = null; }, 5000);
+ }
+ return true;
+ }
+ });
+ }
+
+ _show_obliterate_modal()
+ {
+ const { error_div } = this._show_module_input_modal({
+ title: "\uD83D\uDD25 Obliterate Module",
+ submit_label: "\uD83D\uDD25 Obliterate",
+ warning: "\uD83D\uDD25 WARNING: This action is irreversible! \uD83D\uDD25\nAll local and backend data will be permanently destroyed.",
+ on_submit: async (moduleId) => {
+ const resp = await fetch(`/hub/modules/${encodeURIComponent(moduleId)}`, { method: "DELETE" });
+ if (resp.ok)
+ {
+ return true;
+ }
+ const msg = await resp.text();
+ error_div.textContent = msg || ("HTTP " + resp.status);
+ return false;
+ }
+ });
+ }
+
+ _navigate_to_module(moduleId)
+ {
+ const idx = this._modules_data.findIndex(m => m.moduleId === moduleId);
+ if (idx >= 0)
+ {
+ this._page = Math.floor(idx / this._page_size);
+ this._render_page();
+ }
+ }
+
+ _flash_module(id)
+ {
+ const cached = this._row_cache.get(id);
+ if (cached) { flash_highlight(cached.tr); }
+ }
+
}
diff --git a/src/zenserver/frontend/html/pages/orchestrator.js b/src/zenserver/frontend/html/pages/orchestrator.js
index 4a9290a3c..d11306998 100644
--- a/src/zenserver/frontend/html/pages/orchestrator.js
+++ b/src/zenserver/frontend/html/pages/orchestrator.js
@@ -5,7 +5,7 @@
import { ZenPage } from "./page.js"
import { Fetcher } from "../util/fetcher.js"
import { Friendly } from "../util/friendly.js"
-import { Table } from "../util/widgets.js"
+import { Table, add_copy_button } from "../util/widgets.js"
////////////////////////////////////////////////////////////////////////////////
export class Page extends ZenPage
@@ -14,6 +14,14 @@ export class Page extends ZenPage
{
this.set_title("orchestrator");
+ // Provisioner section (hidden until data arrives)
+ this._prov_section = this._collapsible_section("Provisioner");
+ this._prov_section._parent.inner().style.display = "none";
+ this._prov_grid = null;
+ this._prov_target_dirty = false;
+ this._prov_commit_timer = null;
+ this._prov_last_target = null;
+
// Agents section
const agents_section = this._collapsible_section("Compute Agents");
this._agents_host = agents_section;
@@ -46,48 +54,16 @@ export class Page extends ZenPage
this._connect_ws();
}
- _collapsible_section(name)
- {
- const section = this.add_section(name);
- const container = section._parent.inner();
- const heading = container.firstElementChild;
-
- heading.style.cursor = "pointer";
- heading.style.userSelect = "none";
-
- const indicator = document.createElement("span");
- indicator.textContent = " \u25BC";
- indicator.style.fontSize = "0.7em";
- heading.appendChild(indicator);
-
- let collapsed = false;
- heading.addEventListener("click", (e) => {
- if (e.target !== heading && e.target !== indicator)
- {
- return;
- }
- collapsed = !collapsed;
- indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
- let sibling = heading.nextElementSibling;
- while (sibling)
- {
- sibling.style.display = collapsed ? "none" : "";
- sibling = sibling.nextElementSibling;
- }
- });
-
- return section;
- }
-
async _fetch_all()
{
try
{
- const [agents, history, clients, client_history] = await Promise.all([
+ const [agents, history, clients, client_history, prov] = await Promise.all([
new Fetcher().resource("/orch/agents").json(),
new Fetcher().resource("/orch/history").param("limit", "50").json().catch(() => null),
new Fetcher().resource("/orch/clients").json().catch(() => null),
new Fetcher().resource("/orch/clients/history").param("limit", "50").json().catch(() => null),
+ new Fetcher().resource("/orch/provisioner/status").json().catch(() => null),
]);
this._render_agents(agents);
@@ -103,6 +79,7 @@ export class Page extends ZenPage
{
this._render_client_history(client_history.client_events || []);
}
+ this._render_provisioner(prov);
}
catch (e) { /* service unavailable */ }
}
@@ -142,6 +119,7 @@ export class Page extends ZenPage
{
this._render_client_history(data.client_events);
}
+ this._render_provisioner(data.provisioner);
}
catch (e) { /* ignore parse errors */ }
};
@@ -189,7 +167,7 @@ export class Page extends ZenPage
return;
}
- let totalCpus = 0, totalWeightedCpu = 0;
+ let totalCpus = 0, activeCpus = 0, totalWeightedCpu = 0;
let totalMemUsed = 0, totalMemTotal = 0;
let totalQueues = 0, totalPending = 0, totalRunning = 0, totalCompleted = 0;
let totalRecv = 0, totalSent = 0;
@@ -206,8 +184,14 @@ export class Page extends ZenPage
const completed = w.actions_completed || 0;
const recv = w.bytes_received || 0;
const sent = w.bytes_sent || 0;
+ const provisioner = w.provisioner || "";
+ const isProvisioned = provisioner !== "";
totalCpus += cpus;
+ if (w.provisioner_status === "active")
+ {
+ activeCpus += cpus;
+ }
if (cpus > 0 && typeof cpuUsage === "number")
{
totalWeightedCpu += cpuUsage * cpus;
@@ -242,12 +226,49 @@ export class Page extends ZenPage
cell.inner().textContent = "";
cell.tag("a").text(hostname).attr("href", w.uri + "/dashboard/compute/").attr("target", "_blank");
}
+
+ // Visual treatment based on provisioner status
+ const provStatus = w.provisioner_status || "";
+ if (!isProvisioned)
+ {
+ row.inner().style.opacity = "0.45";
+ }
+ else
+ {
+ const hostCell = row.get_cell(0);
+ const el = hostCell.inner();
+ const badge = document.createElement("span");
+ const badgeBase = "display:inline-block;margin-left:6px;padding:1px 5px;border-radius:8px;" +
+ "font-size:9px;font-weight:600;color:#fff;vertical-align:middle;";
+
+ if (provStatus === "draining")
+ {
+ badge.textContent = "draining";
+ badge.style.cssText = badgeBase + "background:var(--theme_warn);";
+ row.inner().style.opacity = "0.6";
+ }
+ else if (provStatus === "active")
+ {
+ badge.textContent = provisioner;
+ badge.style.cssText = badgeBase + "background:#8957e5;";
+ }
+ else
+ {
+ badge.textContent = "deallocated";
+ badge.style.cssText = badgeBase + "background:var(--theme_fail);";
+ row.inner().style.opacity = "0.45";
+ }
+ el.appendChild(badge);
+ }
}
- // Total row
+ // Total row — show active / total in CPUs column
+ const cpuLabel = activeCpus < totalCpus
+ ? Friendly.sep(activeCpus) + " / " + Friendly.sep(totalCpus)
+ : Friendly.sep(totalCpus);
const total = this._agents_table.add_row(
"TOTAL",
- Friendly.sep(totalCpus),
+ cpuLabel,
"",
totalMemTotal > 0 ? Friendly.bytes(totalMemUsed) + " / " + Friendly.bytes(totalMemTotal) : "-",
Friendly.sep(totalQueues),
@@ -277,12 +298,13 @@ export class Page extends ZenPage
for (const c of clients)
{
- this._clients_table.add_row(
+ const crow = this._clients_table.add_row(
c.id || "",
c.hostname || "",
c.address || "",
this._format_last_seen(c.dt),
);
+ if (c.id) { add_copy_button(crow.get_cell(0).inner(), c.id); }
}
}
@@ -338,6 +360,154 @@ export class Page extends ZenPage
}
}
+ _render_provisioner(prov)
+ {
+ const container = this._prov_section._parent.inner();
+
+ if (!prov || !prov.name)
+ {
+ container.style.display = "none";
+ return;
+ }
+ container.style.display = "";
+
+ if (!this._prov_grid)
+ {
+ this._prov_grid = this._prov_section.tag().classify("grid").classify("stats-tiles");
+ this._prov_tiles = {};
+
+ // Target cores tile with editable input
+ const target_tile = this._prov_grid.tag().classify("card").classify("stats-tile");
+ target_tile.tag().classify("card-title").text("Target Cores");
+ const target_body = target_tile.tag().classify("tile-metrics");
+ const target_m = target_body.tag().classify("tile-metric").classify("tile-metric-hero");
+ const input = document.createElement("input");
+ input.type = "number";
+ input.min = "0";
+ input.style.cssText = "width:100px;padding:4px 8px;border:1px solid var(--theme_g2);border-radius:4px;" +
+ "background:var(--theme_g4);color:var(--theme_bright);font-size:20px;font-weight:600;text-align:right;";
+ target_m.inner().appendChild(input);
+ target_m.tag().classify("metric-label").text("target");
+ this._prov_tiles.target_input = input;
+
+ input.addEventListener("focus", () => { this._prov_target_dirty = true; });
+ input.addEventListener("input", () => {
+ this._prov_target_dirty = true;
+ if (this._prov_commit_timer)
+ {
+ clearTimeout(this._prov_commit_timer);
+ }
+ this._prov_commit_timer = setTimeout(() => this._commit_provisioner_target(), 800);
+ });
+ input.addEventListener("keydown", (e) => {
+ if (e.key === "Enter")
+ {
+ if (this._prov_commit_timer)
+ {
+ clearTimeout(this._prov_commit_timer);
+ }
+ this._commit_provisioner_target();
+ input.blur();
+ }
+ });
+ input.addEventListener("blur", () => {
+ if (this._prov_commit_timer)
+ {
+ clearTimeout(this._prov_commit_timer);
+ }
+ this._commit_provisioner_target();
+ });
+
+ // Active cores
+ const active_tile = this._prov_grid.tag().classify("card").classify("stats-tile");
+ active_tile.tag().classify("card-title").text("Active Cores");
+ const active_body = active_tile.tag().classify("tile-metrics");
+ this._prov_tiles.active = active_body;
+
+ // Estimated cores
+ const est_tile = this._prov_grid.tag().classify("card").classify("stats-tile");
+ est_tile.tag().classify("card-title").text("Estimated Cores");
+ const est_body = est_tile.tag().classify("tile-metrics");
+ this._prov_tiles.estimated = est_body;
+
+ // Agents
+ const agents_tile = this._prov_grid.tag().classify("card").classify("stats-tile");
+ agents_tile.tag().classify("card-title").text("Agents");
+ const agents_body = agents_tile.tag().classify("tile-metrics");
+ this._prov_tiles.agents = agents_body;
+
+ // Draining
+ const drain_tile = this._prov_grid.tag().classify("card").classify("stats-tile");
+ drain_tile.tag().classify("card-title").text("Draining");
+ const drain_body = drain_tile.tag().classify("tile-metrics");
+ this._prov_tiles.draining = drain_body;
+ }
+
+ // Update values
+ const input = this._prov_tiles.target_input;
+ if (!this._prov_target_dirty && document.activeElement !== input)
+ {
+ input.value = prov.target_cores;
+ }
+ this._prov_last_target = prov.target_cores;
+
+ // Re-render metric tiles (clear and recreate content)
+ for (const key of ["active", "estimated", "agents", "draining"])
+ {
+ this._prov_tiles[key].inner().innerHTML = "";
+ }
+ this._metric(this._prov_tiles.active, Friendly.sep(prov.active_cores), "cores", true);
+ this._metric(this._prov_tiles.estimated, Friendly.sep(prov.estimated_cores), "cores", true);
+ this._metric(this._prov_tiles.agents, Friendly.sep(prov.agents), "active", true);
+ this._metric(this._prov_tiles.draining, Friendly.sep(prov.agents_draining || 0), "agents", true);
+ }
+
+ async _commit_provisioner_target()
+ {
+ const input = this._prov_tiles?.target_input;
+ if (!input || this._prov_committing)
+ {
+ return;
+ }
+ const value = parseInt(input.value, 10);
+ if (isNaN(value) || value < 0)
+ {
+ return;
+ }
+ if (value === this._prov_last_target)
+ {
+ this._prov_target_dirty = false;
+ return;
+ }
+ this._prov_committing = true;
+ try
+ {
+ const resp = await fetch("/orch/provisioner/target", {
+ method: "POST",
+ headers: { "Content-Type": "application/json" },
+ body: JSON.stringify({ target_cores: value }),
+ });
+ if (resp.ok)
+ {
+ this._prov_target_dirty = false;
+ console.log("Target cores set to", value);
+ }
+ else
+ {
+ const text = await resp.text();
+ console.error("Failed to set target cores: HTTP", resp.status, text);
+ }
+ }
+ catch (e)
+ {
+ console.error("Failed to set target cores:", e);
+ }
+ finally
+ {
+ this._prov_committing = false;
+ }
+ }
+
_metric(parent, value, label, hero = false)
{
const m = parent.tag().classify("tile-metric");
diff --git a/src/zenserver/frontend/html/pages/page.js b/src/zenserver/frontend/html/pages/page.js
index cf8d3e3dd..3653abb0e 100644
--- a/src/zenserver/frontend/html/pages/page.js
+++ b/src/zenserver/frontend/html/pages/page.js
@@ -6,6 +6,26 @@ import { WidgetHost } from "../util/widgets.js"
import { Fetcher } from "../util/fetcher.js"
import { Friendly } from "../util/friendly.js"
+function _deep_merge_stats(base, update)
+{
+ const result = Object.assign({}, base);
+ for (const key of Object.keys(update))
+ {
+ const bv = result[key];
+ const uv = update[key];
+ if (uv && typeof uv === "object" && !Array.isArray(uv)
+ && bv && typeof bv === "object" && !Array.isArray(bv))
+ {
+ result[key] = _deep_merge_stats(bv, uv);
+ }
+ else
+ {
+ result[key] = uv;
+ }
+ }
+ return result;
+}
+
////////////////////////////////////////////////////////////////////////////////
export class PageBase extends WidgetHost
{
@@ -282,10 +302,7 @@ export class ZenPage extends PageBase
_render_http_requests_tile(grid, req, bad_requests = undefined)
{
- if (!req)
- {
- return;
- }
+ req = req || {};
const tile = grid.tag().classify("card").classify("stats-tile");
tile.tag().classify("card-title").text("HTTP Requests");
const columns = tile.tag().classify("tile-columns");
@@ -337,4 +354,47 @@ export class ZenPage extends PageBase
this._metric(right, Friendly.duration(reqData.t_max), "max");
}
}
+
+ _merge_last_stats(stats)
+ {
+ if (this._last_stats)
+ {
+ stats = _deep_merge_stats(this._last_stats, stats);
+ }
+ this._last_stats = stats;
+ return stats;
+ }
+
+ _collapsible_section(name)
+ {
+ const section = this.add_section(name);
+ const container = section._parent.inner();
+ const heading = container.firstElementChild;
+
+ heading.style.cursor = "pointer";
+ heading.style.userSelect = "none";
+
+ const indicator = document.createElement("span");
+ indicator.textContent = " \u25BC";
+ indicator.style.fontSize = "0.7em";
+ heading.appendChild(indicator);
+
+ let collapsed = false;
+ heading.addEventListener("click", (e) => {
+ if (e.target !== heading && e.target !== indicator)
+ {
+ return;
+ }
+ collapsed = !collapsed;
+ indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
+ let sibling = heading.nextElementSibling;
+ while (sibling)
+ {
+ sibling.style.display = collapsed ? "none" : "";
+ sibling = sibling.nextElementSibling;
+ }
+ });
+
+ return section;
+ }
}
diff --git a/src/zenserver/frontend/html/pages/projects.js b/src/zenserver/frontend/html/pages/projects.js
index 2469bf70b..2e76a80f1 100644
--- a/src/zenserver/frontend/html/pages/projects.js
+++ b/src/zenserver/frontend/html/pages/projects.js
@@ -6,7 +6,7 @@ import { ZenPage } from "./page.js"
import { Fetcher } from "../util/fetcher.js"
import { Friendly } from "../util/friendly.js"
import { Modal } from "../util/modal.js"
-import { Table, Toolbar } from "../util/widgets.js"
+import { Table, Toolbar, Pager, add_copy_button } from "../util/widgets.js"
////////////////////////////////////////////////////////////////////////////////
export class Page extends ZenPage
@@ -39,8 +39,6 @@ export class Page extends ZenPage
// Projects list
var section = this._collapsible_section("Projects");
- section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all());
-
var columns = [
"name",
"project dir",
@@ -51,51 +49,21 @@ export class Page extends ZenPage
this._project_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight|Table.Flag_Sortable|Table.Flag_AlignNumeric);
- var projects = await new Fetcher().resource("/prj/list").json();
- projects.sort((a, b) => (b.LastAccessTime || 0) - (a.LastAccessTime || 0));
-
- for (const project of projects)
- {
- var row = this._project_table.add_row(
- "",
- "",
- "",
- "",
- );
-
- var cell = row.get_cell(0);
- cell.tag().text(project.Id).on_click(() => this.view_project(project.Id));
-
- if (project.ProjectRootDir)
- {
- row.get_cell(1).tag("a").text(project.ProjectRootDir)
- .attr("href", "vscode://" + project.ProjectRootDir.replace(/\\/g, "/"));
- }
- if (project.EngineRootDir)
- {
- row.get_cell(2).tag("a").text(project.EngineRootDir)
- .attr("href", "vscode://" + project.EngineRootDir.replace(/\\/g, "/"));
- }
-
- cell = row.get_cell(-1);
- const action_tb = new Toolbar(cell, true).left();
- action_tb.add("view").on_click(() => this.view_project(project.Id));
- action_tb.add("drop").on_click(() => this.drop_project(project.Id));
-
- row.attr("zs_name", project.Id);
-
- // Fetch project details to get oplog count
- new Fetcher().resource("prj", project.Id).json().then((info) => {
- const oplogs = info["oplogs"] || [];
- row.get_cell(3).text(Friendly.sep(oplogs.length)).style("textAlign", "right");
- // Right-align the corresponding header cell
- const header = this._project_table._element.firstElementChild;
- if (header && header.children[4])
- {
- header.children[4].style.textAlign = "right";
- }
- }).catch(() => {});
- }
+ this._project_pager = new Pager(section, 25, () => this._render_projects_page(),
+ Pager.make_search_fn(() => this._projects_data, p => p.Id));
+ const drop_link = document.createElement("span");
+ drop_link.className = "dropall zen_action";
+ drop_link.style.position = "static";
+ drop_link.textContent = "drop-all";
+ drop_link.addEventListener("click", () => this.drop_all());
+ this._project_pager.prepend(drop_link);
+
+ const loading = Pager.loading(section);
+ this._projects_data = await new Fetcher().resource("/prj/list").json();
+ this._projects_data.sort((a, b) => a.Id.localeCompare(b.Id));
+ this._project_pager.set_total(this._projects_data.length);
+ this._render_projects_page();
+ loading.remove();
// Project detail area (inside projects section so it collapses together)
this._project_host = section;
@@ -110,39 +78,6 @@ export class Page extends ZenPage
}
}
- _collapsible_section(name)
- {
- const section = this.add_section(name);
- const container = section._parent.inner();
- const heading = container.firstElementChild;
-
- heading.style.cursor = "pointer";
- heading.style.userSelect = "none";
-
- const indicator = document.createElement("span");
- indicator.textContent = " \u25BC";
- indicator.style.fontSize = "0.7em";
- heading.appendChild(indicator);
-
- let collapsed = false;
- heading.addEventListener("click", (e) => {
- if (e.target !== heading && e.target !== indicator)
- {
- return;
- }
- collapsed = !collapsed;
- indicator.textContent = collapsed ? " \u25B6" : " \u25BC";
- let sibling = heading.nextElementSibling;
- while (sibling)
- {
- sibling.style.display = collapsed ? "none" : "";
- sibling = sibling.nextElementSibling;
- }
- });
-
- return section;
- }
-
_clear_param(name)
{
this._params.delete(name);
@@ -153,6 +88,7 @@ export class Page extends ZenPage
_render_stats(stats)
{
+ stats = this._merge_last_stats(stats);
const safe = (obj, path) => path.split(".").reduce((a, b) => a && a[b], obj);
const grid = this._stats_grid;
@@ -163,54 +99,48 @@ export class Page extends ZenPage
// Store Operations tile
{
- const store = safe(stats, "store");
- if (store)
- {
- const tile = grid.tag().classify("card").classify("stats-tile");
- tile.tag().classify("card-title").text("Store Operations");
- const columns = tile.tag().classify("tile-columns");
-
- const left = columns.tag().classify("tile-metrics");
- const proj = store.project || {};
- this._metric(left, Friendly.sep(proj.readcount || 0), "project reads", true);
- this._metric(left, Friendly.sep(proj.writecount || 0), "project writes");
- this._metric(left, Friendly.sep(proj.deletecount || 0), "project deletes");
-
- const right = columns.tag().classify("tile-metrics");
- const oplog = store.oplog || {};
- this._metric(right, Friendly.sep(oplog.readcount || 0), "oplog reads", true);
- this._metric(right, Friendly.sep(oplog.writecount || 0), "oplog writes");
- this._metric(right, Friendly.sep(oplog.deletecount || 0), "oplog deletes");
- }
+ const store = safe(stats, "store") || {};
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Store Operations");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ const proj = store.project || {};
+ this._metric(left, Friendly.sep(proj.readcount || 0), "project reads", true);
+ this._metric(left, Friendly.sep(proj.writecount || 0), "project writes");
+ this._metric(left, Friendly.sep(proj.deletecount || 0), "project deletes");
+
+ const right = columns.tag().classify("tile-metrics");
+ const oplog = store.oplog || {};
+ this._metric(right, Friendly.sep(oplog.readcount || 0), "oplog reads", true);
+ this._metric(right, Friendly.sep(oplog.writecount || 0), "oplog writes");
+ this._metric(right, Friendly.sep(oplog.deletecount || 0), "oplog deletes");
}
// Op & Chunk tile
{
- const store = safe(stats, "store");
- if (store)
- {
- const tile = grid.tag().classify("card").classify("stats-tile");
- tile.tag().classify("card-title").text("Ops & Chunks");
- const columns = tile.tag().classify("tile-columns");
-
- const left = columns.tag().classify("tile-metrics");
- const op = store.op || {};
- const opTotal = (op.hitcount || 0) + (op.misscount || 0);
- const opRatio = opTotal > 0 ? (((op.hitcount || 0) / opTotal) * 100).toFixed(1) + "%" : "-";
- this._metric(left, opRatio, "op hit ratio", true);
- this._metric(left, Friendly.sep(op.hitcount || 0), "op hits");
- this._metric(left, Friendly.sep(op.misscount || 0), "op misses");
- this._metric(left, Friendly.sep(op.writecount || 0), "op writes");
-
- const right = columns.tag().classify("tile-metrics");
- const chunk = store.chunk || {};
- const chunkTotal = (chunk.hitcount || 0) + (chunk.misscount || 0);
- const chunkRatio = chunkTotal > 0 ? (((chunk.hitcount || 0) / chunkTotal) * 100).toFixed(1) + "%" : "-";
- this._metric(right, chunkRatio, "chunk hit ratio", true);
- this._metric(right, Friendly.sep(chunk.hitcount || 0), "chunk hits");
- this._metric(right, Friendly.sep(chunk.misscount || 0), "chunk misses");
- this._metric(right, Friendly.sep(chunk.writecount || 0), "chunk writes");
- }
+ const store = safe(stats, "store") || {};
+ const tile = grid.tag().classify("card").classify("stats-tile");
+ tile.tag().classify("card-title").text("Ops & Chunks");
+ const columns = tile.tag().classify("tile-columns");
+
+ const left = columns.tag().classify("tile-metrics");
+ const op = store.op || {};
+ const opTotal = (op.hitcount || 0) + (op.misscount || 0);
+ const opRatio = opTotal > 0 ? (((op.hitcount || 0) / opTotal) * 100).toFixed(1) + "%" : "-";
+ this._metric(left, opRatio, "op hit ratio", true);
+ this._metric(left, Friendly.sep(op.hitcount || 0), "op hits");
+ this._metric(left, Friendly.sep(op.misscount || 0), "op misses");
+ this._metric(left, Friendly.sep(op.writecount || 0), "op writes");
+
+ const right = columns.tag().classify("tile-metrics");
+ const chunk = store.chunk || {};
+ const chunkTotal = (chunk.hitcount || 0) + (chunk.misscount || 0);
+ const chunkRatio = chunkTotal > 0 ? (((chunk.hitcount || 0) / chunkTotal) * 100).toFixed(1) + "%" : "-";
+ this._metric(right, chunkRatio, "chunk hit ratio", true);
+ this._metric(right, Friendly.sep(chunk.hitcount || 0), "chunk hits");
+ this._metric(right, Friendly.sep(chunk.misscount || 0), "chunk misses");
+ this._metric(right, Friendly.sep(chunk.writecount || 0), "chunk writes");
}
// Storage tile
@@ -231,6 +161,57 @@ export class Page extends ZenPage
}
}
+ _render_projects_page()
+ {
+ const { start, end } = this._project_pager.page_range();
+ this._project_table.clear(start);
+ for (let i = start; i < end; i++)
+ {
+ const project = this._projects_data[i];
+ const row = this._project_table.add_row(
+ "",
+ "",
+ "",
+ "",
+ );
+
+ const cell = row.get_cell(0);
+ cell.tag().text(project.Id).on_click(() => this.view_project(project.Id));
+ add_copy_button(cell.inner(), project.Id);
+
+ if (project.ProjectRootDir)
+ {
+ row.get_cell(1).tag("a").text(project.ProjectRootDir)
+ .attr("href", "vscode://" + project.ProjectRootDir.replace(/\\/g, "/"));
+ add_copy_button(row.get_cell(1).inner(), project.ProjectRootDir);
+ }
+ if (project.EngineRootDir)
+ {
+ row.get_cell(2).tag("a").text(project.EngineRootDir)
+ .attr("href", "vscode://" + project.EngineRootDir.replace(/\\/g, "/"));
+ add_copy_button(row.get_cell(2).inner(), project.EngineRootDir);
+ }
+
+ const action_cell = row.get_cell(-1);
+ const action_tb = new Toolbar(action_cell, true).left();
+ action_tb.add("view").on_click(() => this.view_project(project.Id));
+ action_tb.add("drop").on_click(() => this.drop_project(project.Id));
+
+ row.attr("zs_name", project.Id);
+
+ new Fetcher().resource("prj", project.Id).json().then((info) => {
+ const oplogs = info["oplogs"] || [];
+ row.get_cell(3).text(Friendly.sep(oplogs.length)).style("textAlign", "right");
+ }).catch(() => {});
+ }
+
+ const header = this._project_table._element.firstElementChild;
+ if (header && header.children[4])
+ {
+ header.children[4].style.textAlign = "right";
+ }
+ }
+
async view_project(project_id)
{
// Toggle off if already selected
@@ -351,10 +332,9 @@ export class Page extends ZenPage
async drop_all()
{
const drop = async () => {
- for (const row of this._project_table)
+ for (const project of this._projects_data || [])
{
- const project_id = row.attr("zs_name");
- await new Fetcher().resource("prj", project_id).delete();
+ await new Fetcher().resource("prj", project.Id).delete();
}
this.reload();
};
diff --git a/src/zenserver/frontend/html/pages/start.js b/src/zenserver/frontend/html/pages/start.js
index e5b4d14f1..d06040b2f 100644
--- a/src/zenserver/frontend/html/pages/start.js
+++ b/src/zenserver/frontend/html/pages/start.js
@@ -6,7 +6,7 @@ import { ZenPage } from "./page.js"
import { Fetcher } from "../util/fetcher.js"
import { Friendly } from "../util/friendly.js"
import { Modal } from "../util/modal.js"
-import { Table, Toolbar } from "../util/widgets.js"
+import { Table, Toolbar, Pager } from "../util/widgets.js"
////////////////////////////////////////////////////////////////////////////////
export class Page extends ZenPage
@@ -50,54 +50,40 @@ export class Page extends ZenPage
this._render_stats(all_stats);
// project list
- var project_table = null;
if (available.has("/prj/"))
{
var section = this.add_section("Cooked Projects");
- section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("projects"));
-
var columns = [
"name",
"project_dir",
"engine_dir",
"actions",
];
- project_table = section.add_widget(Table, columns);
-
- var projects = await new Fetcher().resource("/prj/list").json();
- projects.sort((a, b) => (b.LastAccessTime || 0) - (a.LastAccessTime || 0));
- projects = projects.slice(0, 25);
- projects.sort((a, b) => a.Id.localeCompare(b.Id));
-
- for (const project of projects)
- {
- var row = project_table.add_row(
- "",
- project.ProjectRootDir,
- project.EngineRootDir,
- );
-
- var cell = row.get_cell(0);
- cell.tag().text(project.Id).on_click((x) => this.view_project(x), project.Id);
-
- var cell = row.get_cell(-1);
- var action_tb = new Toolbar(cell, true);
- action_tb.left().add("view").on_click((x) => this.view_project(x), project.Id);
- action_tb.left().add("drop").on_click((x) => this.drop_project(x), project.Id);
-
- row.attr("zs_name", project.Id);
- }
+ this._project_table = section.add_widget(Table, columns);
+
+ this._project_pager = new Pager(section, 25, () => this._render_projects_page(),
+ Pager.make_search_fn(() => this._projects_data, p => p.Id));
+ const drop_link = document.createElement("span");
+ drop_link.className = "dropall zen_action";
+ drop_link.style.position = "static";
+ drop_link.textContent = "drop-all";
+ drop_link.addEventListener("click", () => this.drop_all("projects"));
+ this._project_pager.prepend(drop_link);
+
+ const prj_loading = Pager.loading(section);
+ this._projects_data = await new Fetcher().resource("/prj/list").json();
+ this._projects_data.sort((a, b) => a.Id.localeCompare(b.Id));
+ this._project_pager.set_total(this._projects_data.length);
+ this._render_projects_page();
+ prj_loading.remove();
}
// cache
- var cache_table = null;
if (available.has("/z$/"))
{
var section = this.add_section("Cache");
- section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("z$"));
-
var columns = [
"namespace",
"dir",
@@ -107,30 +93,30 @@ export class Page extends ZenPage
"size mem",
"actions",
];
- var zcache_info = await new Fetcher().resource("/z$/").json();
- cache_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight);
- for (const namespace of zcache_info["Namespaces"] || [])
- {
- new Fetcher().resource(`/z$/${namespace}/`).json().then((data) => {
- const row = cache_table.add_row(
- "",
- data["Configuration"]["RootDir"],
- data["Buckets"].length,
- data["EntryCount"],
- Friendly.bytes(data["StorageSize"].DiskSize),
- Friendly.bytes(data["StorageSize"].MemorySize)
- );
- var cell = row.get_cell(0);
- cell.tag().text(namespace).on_click(() => this.view_zcache(namespace));
-
- cell = row.get_cell(-1);
- const action_tb = new Toolbar(cell, true);
- action_tb.left().add("view").on_click(() => this.view_zcache(namespace));
- action_tb.left().add("drop").on_click(() => this.drop_zcache(namespace));
-
- row.attr("zs_name", namespace);
- });
- }
+ this._cache_table = section.add_widget(Table, columns, Table.Flag_FitLeft|Table.Flag_PackRight);
+
+ this._cache_pager = new Pager(section, 25, () => this._render_cache_page(),
+ Pager.make_search_fn(() => this._cache_data, item => item.namespace));
+ const cache_drop_link = document.createElement("span");
+ cache_drop_link.className = "dropall zen_action";
+ cache_drop_link.style.position = "static";
+ cache_drop_link.textContent = "drop-all";
+ cache_drop_link.addEventListener("click", () => this.drop_all("z$"));
+ this._cache_pager.prepend(cache_drop_link);
+
+ const cache_loading = Pager.loading(section);
+ const zcache_info = await new Fetcher().resource("/z$/").json();
+ const namespaces = zcache_info["Namespaces"] || [];
+ const results = await Promise.allSettled(
+ namespaces.map(ns => new Fetcher().resource(`/z$/${ns}/`).json().then(data => ({ namespace: ns, data })))
+ );
+ this._cache_data = results
+ .filter(r => r.status === "fulfilled")
+ .map(r => r.value)
+ .sort((a, b) => a.namespace.localeCompare(b.namespace));
+ this._cache_pager.set_total(this._cache_data.length);
+ this._render_cache_page();
+ cache_loading.remove();
}
// version
@@ -139,15 +125,13 @@ export class Page extends ZenPage
version.param("detailed", "true");
version.text().then((data) => ver_tag.text(data));
- this._project_table = project_table;
- this._cache_table = cache_table;
-
// WebSocket for live stats updates
this.connect_stats_ws((all_stats) => this._render_stats(all_stats));
}
_render_stats(all_stats)
{
+ all_stats = this._merge_last_stats(all_stats);
const grid = this._stats_grid;
const safe_lookup = this._safe_lookup;
@@ -316,6 +300,60 @@ export class Page extends ZenPage
m.tag().classify("metric-label").text(label);
}
+ _render_projects_page()
+ {
+ const { start, end } = this._project_pager.page_range();
+ this._project_table.clear(start);
+ for (let i = start; i < end; i++)
+ {
+ const project = this._projects_data[i];
+ const row = this._project_table.add_row(
+ "",
+ project.ProjectRootDir,
+ project.EngineRootDir,
+ );
+
+ const cell = row.get_cell(0);
+ cell.tag().text(project.Id).on_click((x) => this.view_project(x), project.Id);
+
+ const action_cell = row.get_cell(-1);
+ const action_tb = new Toolbar(action_cell, true);
+ action_tb.left().add("view").on_click((x) => this.view_project(x), project.Id);
+ action_tb.left().add("drop").on_click((x) => this.drop_project(x), project.Id);
+
+ row.attr("zs_name", project.Id);
+ }
+ }
+
+ _render_cache_page()
+ {
+ const { start, end } = this._cache_pager.page_range();
+ this._cache_table.clear(start);
+ for (let i = start; i < end; i++)
+ {
+ const item = this._cache_data[i];
+ const data = item.data;
+ const row = this._cache_table.add_row(
+ "",
+ data["Configuration"]["RootDir"],
+ data["Buckets"].length,
+ data["EntryCount"],
+ Friendly.bytes(data["StorageSize"].DiskSize),
+ Friendly.bytes(data["StorageSize"].MemorySize)
+ );
+
+ const cell = row.get_cell(0);
+ cell.tag().text(item.namespace).on_click(() => this.view_zcache(item.namespace));
+
+ const action_cell = row.get_cell(-1);
+ const action_tb = new Toolbar(action_cell, true);
+ action_tb.left().add("view").on_click(() => this.view_zcache(item.namespace));
+ action_tb.left().add("drop").on_click(() => this.drop_zcache(item.namespace));
+
+ row.attr("zs_name", item.namespace);
+ }
+ }
+
view_stat(provider)
{
window.location = "?page=stat&provider=" + provider;
@@ -361,20 +399,18 @@ export class Page extends ZenPage
async drop_all_projects()
{
- for (const row of this._project_table)
+ for (const project of this._projects_data || [])
{
- const project_id = row.attr("zs_name");
- await new Fetcher().resource("prj", project_id).delete();
+ await new Fetcher().resource("prj", project.Id).delete();
}
this.reload();
}
async drop_all_zcache()
{
- for (const row of this._cache_table)
+ for (const item of this._cache_data || [])
{
- const namespace = row.attr("zs_name");
- await new Fetcher().resource("z$", namespace).delete();
+ await new Fetcher().resource("z$", item.namespace).delete();
}
this.reload();
}
diff --git a/src/zenserver/frontend/html/pages/workspaces.js b/src/zenserver/frontend/html/pages/workspaces.js
index d31fd7373..db02e8be1 100644
--- a/src/zenserver/frontend/html/pages/workspaces.js
+++ b/src/zenserver/frontend/html/pages/workspaces.js
@@ -4,6 +4,7 @@
import { ZenPage } from "./page.js"
import { Fetcher } from "../util/fetcher.js"
+import { copy_button } from "../util/widgets.js"
////////////////////////////////////////////////////////////////////////////////
export class Page extends ZenPage
@@ -13,7 +14,7 @@ export class Page extends ZenPage
this.set_title("workspaces");
// Workspace Service Stats
- const stats_section = this.add_section("Workspace Service Stats");
+ const stats_section = this._collapsible_section("Workspace Service Stats");
this._stats_grid = stats_section.tag().classify("grid").classify("stats-tiles");
const stats = await new Fetcher().resource("stats", "ws").json().catch(() => null);
@@ -157,6 +158,7 @@ export class Page extends ZenPage
id_wrap.className = "ws-id-wrap";
id_wrap.appendChild(btn_expand);
id_wrap.appendChild(document.createTextNode("\u00A0" + id));
+ id_wrap.appendChild(copy_button(id));
const td_id = document.createElement("td");
td_id.appendChild(id_wrap);
tr.appendChild(td_id);
@@ -200,6 +202,7 @@ export class Page extends ZenPage
_render_stats(stats)
{
+ stats = this._merge_last_stats(stats);
const grid = this._stats_grid;
grid.inner().innerHTML = "";
diff --git a/src/zenserver/frontend/html/util/widgets.js b/src/zenserver/frontend/html/util/widgets.js
index 17bd2fde7..651686a11 100644
--- a/src/zenserver/frontend/html/util/widgets.js
+++ b/src/zenserver/frontend/html/util/widgets.js
@@ -6,6 +6,58 @@ import { Component } from "./component.js"
import { Friendly } from "../util/friendly.js"
////////////////////////////////////////////////////////////////////////////////
+export function flash_highlight(element)
+{
+ if (!element) { return; }
+ element.classList.add("pager-search-highlight");
+ setTimeout(() => { element.classList.remove("pager-search-highlight"); }, 1500);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+export function copy_button(value_or_fn)
+{
+ if (!navigator.clipboard)
+ {
+ const stub = document.createElement("span");
+ stub.style.display = "none";
+ return stub;
+ }
+
+ let reset_timer = 0;
+ const btn = document.createElement("button");
+ btn.className = "zen-copy-btn";
+ btn.title = "Copy to clipboard";
+ btn.textContent = "\u29C9";
+ btn.addEventListener("click", async (e) => {
+ e.stopPropagation();
+ const v = typeof value_or_fn === "function" ? value_or_fn() : value_or_fn;
+ if (!v) { return; }
+ try
+ {
+ await navigator.clipboard.writeText(v);
+ clearTimeout(reset_timer);
+ btn.classList.add("zen-copy-ok");
+ btn.textContent = "\u2713";
+ reset_timer = setTimeout(() => { btn.classList.remove("zen-copy-ok"); btn.textContent = "\u29C9"; }, 800);
+ }
+ catch (_e) { /* clipboard not available */ }
+ });
+ return btn;
+}
+
+// Wraps the existing children of `element` plus a copy button into an
+// inline-flex nowrap container so the button never wraps to a new line.
+export function add_copy_button(element, value_or_fn)
+{
+ if (!navigator.clipboard) { return; }
+ const wrap = document.createElement("span");
+ wrap.className = "zen-copy-wrap";
+ while (element.firstChild) { wrap.appendChild(element.firstChild); }
+ wrap.appendChild(copy_button(value_or_fn));
+ element.appendChild(wrap);
+}
+
+////////////////////////////////////////////////////////////////////////////////
class Widget extends Component
{
}
@@ -402,6 +454,135 @@ export class ProgressBar extends Widget
////////////////////////////////////////////////////////////////////////////////
+export class Pager
+{
+ constructor(section, page_size, on_change, search_fn)
+ {
+ this._page = 0;
+ this._page_size = page_size;
+ this._total = 0;
+ this._on_change = on_change;
+ this._search_fn = search_fn || null;
+ this._search_input = null;
+
+ const pager = section.tag().classify("module-pager").inner();
+ this._btn_prev = document.createElement("button");
+ this._btn_prev.className = "module-pager-btn";
+ this._btn_prev.textContent = "\u2190 Prev";
+ this._btn_prev.addEventListener("click", () => this._go_page(this._page - 1));
+ this._label = document.createElement("span");
+ this._label.className = "module-pager-label";
+ this._btn_next = document.createElement("button");
+ this._btn_next.className = "module-pager-btn";
+ this._btn_next.textContent = "Next \u2192";
+ this._btn_next.addEventListener("click", () => this._go_page(this._page + 1));
+
+ if (this._search_fn)
+ {
+ this._search_input = document.createElement("input");
+ this._search_input.type = "text";
+ this._search_input.className = "module-pager-search";
+ this._search_input.placeholder = "Search\u2026";
+ this._search_input.addEventListener("keydown", (e) =>
+ {
+ if (e.key === "Enter")
+ {
+ this._do_search(this._search_input.value.trim());
+ }
+ });
+ pager.appendChild(this._search_input);
+ }
+
+ pager.appendChild(this._btn_prev);
+ pager.appendChild(this._label);
+ pager.appendChild(this._btn_next);
+ this._pager = pager;
+
+ this._update_ui();
+ }
+
+ prepend(element)
+ {
+ const ref = this._search_input || this._btn_prev;
+ this._pager.insertBefore(element, ref);
+ }
+
+ set_total(n)
+ {
+ this._total = n;
+ const max_page = Math.max(0, Math.ceil(n / this._page_size) - 1);
+ if (this._page > max_page)
+ {
+ this._page = max_page;
+ }
+ this._update_ui();
+ }
+
+ page_range()
+ {
+ const start = this._page * this._page_size;
+ const end = Math.min(start + this._page_size, this._total);
+ return { start, end };
+ }
+
+ _go_page(n)
+ {
+ const max = Math.max(0, Math.ceil(this._total / this._page_size) - 1);
+ this._page = Math.max(0, Math.min(n, max));
+ this._update_ui();
+ this._on_change();
+ }
+
+ _do_search(term)
+ {
+ if (!term || !this._search_fn)
+ {
+ return;
+ }
+ const result = this._search_fn(term);
+ if (!result)
+ {
+ this._search_input.style.outline = "2px solid var(--theme_fail)";
+ setTimeout(() => { this._search_input.style.outline = ""; }, 1000);
+ return;
+ }
+ this._go_page(Math.floor(result.index / this._page_size));
+ flash_highlight(this._pager.parentNode.querySelector(`[zs_name="${CSS.escape(result.name)}"]`));
+ }
+
+ _update_ui()
+ {
+ const total = this._total;
+ const page_count = Math.max(1, Math.ceil(total / this._page_size));
+ const start = this._page * this._page_size + 1;
+ const end = Math.min(start + this._page_size - 1, total);
+
+ this._btn_prev.disabled = this._page === 0;
+ this._btn_next.disabled = this._page >= page_count - 1;
+ this._label.textContent = total === 0
+ ? "No items"
+ : `${start}\u2013${end} of ${total}`;
+ }
+
+ static make_search_fn(get_data, get_key)
+ {
+ return (term) => {
+ const t = term.toLowerCase();
+ const data = get_data();
+ const i = data.findIndex(item => get_key(item).toLowerCase().includes(t));
+ return i < 0 ? null : { index: i, name: get_key(data[i]) };
+ };
+ }
+
+ static loading(section)
+ {
+ return section.tag().classify("pager-loading").text("Loading\u2026").inner();
+ }
+}
+
+
+
+////////////////////////////////////////////////////////////////////////////////
export class WidgetHost
{
constructor(parent, depth=1)
diff --git a/src/zenserver/frontend/html/zen.css b/src/zenserver/frontend/html/zen.css
index d9f7491ea..d3c6c9036 100644
--- a/src/zenserver/frontend/html/zen.css
+++ b/src/zenserver/frontend/html/zen.css
@@ -816,6 +816,10 @@ zen-banner + zen-nav::part(nav-bar) {
border-color: var(--theme_p0);
}
+.stats-tile[data-over="true"] {
+ border-color: var(--theme_fail);
+}
+
.stats-tile-detailed {
position: relative;
}
@@ -1607,6 +1611,25 @@ tr:last-child td {
animation: module-dot-deprovisioning-from-provisioned 1s steps(1, end) infinite;
}
+@keyframes module-dot-obliterating-from-provisioned {
+ 0%, 59.9% { background: var(--theme_fail); }
+ 60%, 100% { background: var(--theme_ok); }
+}
+@keyframes module-dot-obliterating-from-hibernated {
+ 0%, 59.9% { background: var(--theme_fail); }
+ 60%, 100% { background: var(--theme_warn); }
+}
+
+.module-state-dot[data-state="obliterating"][data-prev-state="provisioned"] {
+ animation: module-dot-obliterating-from-provisioned 0.5s steps(1, end) infinite;
+}
+.module-state-dot[data-state="obliterating"][data-prev-state="hibernated"] {
+ animation: module-dot-obliterating-from-hibernated 0.5s steps(1, end) infinite;
+}
+.module-state-dot[data-state="obliterating"] {
+ animation: module-dot-obliterating-from-provisioned 0.5s steps(1, end) infinite;
+}
+
.module-action-cell {
white-space: nowrap;
display: flex;
@@ -1726,6 +1749,53 @@ tr:last-child td {
text-align: center;
}
+.module-pager-search {
+ font-size: 12px;
+ padding: 4px 8px;
+ width: 14em;
+ border: 1px solid var(--theme_g2);
+ border-radius: 4px;
+ background: var(--theme_g4);
+ color: var(--theme_g0);
+ outline: none;
+ transition: border-color 0.15s, outline 0.3s;
+}
+
+.module-pager-search:focus {
+ border-color: var(--theme_p0);
+}
+
+.module-pager-search::placeholder {
+ color: var(--theme_g1);
+}
+
+@keyframes pager-search-flash {
+ from { box-shadow: inset 0 0 0 100px var(--theme_p2); }
+ to { box-shadow: inset 0 0 0 100px transparent; }
+}
+
+.zen_table > .pager-search-highlight > div {
+ animation: pager-search-flash 1s linear forwards;
+}
+
+.module-table .pager-search-highlight td {
+ animation: pager-search-flash 1s linear forwards;
+}
+
+@keyframes pager-loading-pulse {
+ 0%, 100% { opacity: 0.6; }
+ 50% { opacity: 0.2; }
+}
+
+.pager-loading {
+ color: var(--theme_g1);
+ font-style: italic;
+ font-size: 14px;
+ font-weight: 600;
+ padding: 12px 0;
+ animation: pager-loading-pulse 1.5s ease-in-out infinite;
+}
+
.module-table td, .module-table th {
padding-top: 4px;
padding-bottom: 4px;
@@ -1746,6 +1816,35 @@ tr:last-child td {
color: var(--theme_bright);
}
+.zen-copy-btn {
+ background: transparent;
+ border: 1px solid var(--theme_g2);
+ border-radius: 4px;
+ color: var(--theme_g1);
+ cursor: pointer;
+ font-size: 12px;
+ line-height: 1;
+ padding: 2px 5px;
+ margin-left: 6px;
+ vertical-align: middle;
+ flex-shrink: 0;
+ transition: background 0.1s, color 0.1s;
+}
+.zen-copy-btn:hover {
+ background: var(--theme_g2);
+ color: var(--theme_bright);
+}
+.zen-copy-btn.zen-copy-ok {
+ color: var(--theme_ok);
+ border-color: var(--theme_ok);
+}
+
+.zen-copy-wrap {
+ display: inline-flex;
+ align-items: center;
+ white-space: nowrap;
+}
+
.module-metrics-row td {
padding: 6px 10px 10px 42px;
background: var(--theme_g3);
diff --git a/src/zenserver/frontend/zipfs.cpp b/src/zenserver/frontend/zipfs.cpp
index c7c8687ca..27b92f33a 100644
--- a/src/zenserver/frontend/zipfs.cpp
+++ b/src/zenserver/frontend/zipfs.cpp
@@ -189,12 +189,12 @@ ZipFs::GetFile(const std::string_view& FileName) const
if (Item.CompressionMethod == 0)
{
- // Stored — point directly into the buffer
+ // Stored - point directly into the buffer
Item.View = MemoryView(FileData, Item.UncompressedSize);
}
else
{
- // Deflate — decompress using zlib
+ // Deflate - decompress using zlib
Item.DecompressedData = IoBuffer(Item.UncompressedSize);
z_stream Stream = {};
diff --git a/src/zenserver/hub/README.md b/src/zenserver/hub/README.md
index 322be3649..c75349fa5 100644
--- a/src/zenserver/hub/README.md
+++ b/src/zenserver/hub/README.md
@@ -3,23 +3,32 @@
The Zen Server can act in a "hub" mode. In this mode, the only services offered are the basic health
and diagnostic services alongside an API to provision and deprovision Storage server instances.
+A module ID is an alphanumeric identifier (hyphens allowed) that identifies a dataset, typically
+associated with a content plug-in module.
+
## Generic Server API
GET `/health` - returns an `OK!` payload when all enabled services are up and responding
## Hub API
-GET `{moduleid}` - alphanumeric identifier to identify a dataset (typically associated with a content plug-in module)
-
-GET `/hub/status` - obtain a summary of the currently live instances
+GET `/hub/status` - obtain a summary of all currently live instances
GET `/hub/modules/{moduleid}` - retrieve information about a module
+DELETE `/hub/modules/{moduleid}` - obliterate a module (permanently destroys all data)
+
POST `/hub/modules/{moduleid}/provision` - provision service for module
POST `/hub/modules/{moduleid}/deprovision` - deprovision service for module
-GET `/hub/stats` - retrieve stats for service
+POST `/hub/modules/{moduleid}/hibernate` - hibernate a provisioned module
+
+POST `/hub/modules/{moduleid}/wake` - wake a hibernated module
+
+GET `/stats/hub` - retrieve stats for the hub service
+
+`/hub/proxy/{port}/{path}` - reverse proxy to a child instance dashboard (all HTTP verbs)
## Hub Configuration
diff --git a/src/zenserver/hub/httphubservice.cpp b/src/zenserver/hub/httphubservice.cpp
index ebefcf2e3..e4b0c28d0 100644
--- a/src/zenserver/hub/httphubservice.cpp
+++ b/src/zenserver/hub/httphubservice.cpp
@@ -2,6 +2,7 @@
#include "httphubservice.h"
+#include "httpproxyhandler.h"
#include "hub.h"
#include "storageserverinstance.h"
@@ -43,10 +44,11 @@ namespace {
}
} // namespace
-HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpStatusService& StatusService)
+HttpHubService::HttpHubService(Hub& Hub, HttpProxyHandler& Proxy, HttpStatsService& StatsService, HttpStatusService& StatusService)
: m_Hub(Hub)
, m_StatsService(StatsService)
, m_StatusService(StatusService)
+, m_Proxy(Proxy)
{
using namespace std::literals;
@@ -67,6 +69,23 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta
return true;
});
+ m_Router.AddMatcher("port", [](std::string_view Str) -> bool {
+ if (Str.empty())
+ {
+ return false;
+ }
+ for (const auto C : Str)
+ {
+ if (!std::isdigit(C))
+ {
+ return false;
+ }
+ }
+ return true;
+ });
+
+ m_Router.AddMatcher("proxypath", [](std::string_view Str) -> bool { return !Str.empty(); });
+
m_Router.RegisterRoute(
"status",
[this](HttpRouterRequest& Req) {
@@ -78,6 +97,10 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta
Obj << "moduleId" << ModuleId;
Obj << "state" << ToString(Info.State);
Obj << "port" << Info.Port;
+ if (Info.StateChangeTime != std::chrono::system_clock::time_point::min())
+ {
+ Obj << "state_change_time" << ToDateTime(Info.StateChangeTime);
+ }
Obj.BeginObject("process_metrics");
{
Obj << "MemoryBytes" << Info.Metrics.MemoryBytes;
@@ -98,6 +121,11 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta
HttpVerb::kGet);
m_Router.RegisterRoute(
+ "deprovision",
+ [this](HttpRouterRequest& Req) { HandleDeprovisionAll(Req.ServerRequest()); },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
"modules/{moduleid}",
[this](HttpRouterRequest& Req) {
std::string_view ModuleId = Req.GetCapture(1);
@@ -229,15 +257,23 @@ HttpHubService::HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpSta
HttpVerb::kPost);
m_Router.RegisterRoute(
- "stats",
+ "proxy/{port}/{proxypath}",
[this](HttpRouterRequest& Req) {
- CbObjectWriter Obj;
- Obj << "currentInstanceCount" << m_Hub.GetInstanceCount();
- Obj << "maxInstanceCount" << m_Hub.GetMaxInstanceCount();
- Obj << "instanceLimit" << m_Hub.GetConfig().InstanceLimit;
- Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ std::string_view PortStr = Req.GetCapture(1);
+
+ // Use RelativeUriWithExtension to preserve the file extension that the
+ // router's URI parser strips (e.g. ".css", ".js") - the upstream server
+ // needs the full path including the extension.
+ std::string_view FullUri = Req.ServerRequest().RelativeUriWithExtension();
+ std::string_view Prefix = "proxy/";
+
+ // FullUri is "proxy/{port}/{path...}" - skip past "proxy/{port}/"
+ size_t PathStart = Prefix.size() + PortStr.size() + 1;
+ std::string_view PathTail = (PathStart < FullUri.size()) ? FullUri.substr(PathStart) : std::string_view{};
+
+ m_Proxy.HandleProxyRequest(Req.ServerRequest(), PortStr, PathTail);
},
- HttpVerb::kGet);
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead);
m_StatsService.RegisterHandler("hub", *this);
m_StatusService.RegisterHandler("hub", *this);
@@ -286,7 +322,37 @@ HttpHubService::HandleStatusRequest(HttpServerRequest& Request)
void
HttpHubService::HandleStatsRequest(HttpServerRequest& Request)
{
- Request.WriteResponse(HttpResponseCode::OK, CollectStats());
+ CbObjectWriter Cbo;
+
+ EmitSnapshot("requests", m_HttpRequests, Cbo);
+
+ Cbo << "currentInstanceCount" << m_Hub.GetInstanceCount();
+ Cbo << "maxInstanceCount" << m_Hub.GetMaxInstanceCount();
+ Cbo << "instanceLimit" << m_Hub.GetConfig().InstanceLimit;
+
+ SystemMetrics SysMetrics;
+ DiskSpace Disk;
+ m_Hub.GetMachineMetrics(SysMetrics, Disk);
+ Cbo.BeginObject("machine");
+ {
+ Cbo << "disk_free_bytes" << Disk.Free;
+ Cbo << "disk_total_bytes" << Disk.Total;
+ Cbo << "memory_avail_mib" << SysMetrics.AvailSystemMemoryMiB;
+ Cbo << "memory_total_mib" << SysMetrics.SystemMemoryMiB;
+ Cbo << "virtual_memory_avail_mib" << SysMetrics.AvailVirtualMemoryMiB;
+ Cbo << "virtual_memory_total_mib" << SysMetrics.VirtualMemoryMiB;
+ }
+ Cbo.EndObject();
+
+ const ResourceMetrics& Limits = m_Hub.GetConfig().ResourceLimits;
+ Cbo.BeginObject("resource_limits");
+ {
+ Cbo << "disk_bytes" << Limits.DiskUsageBytes;
+ Cbo << "memory_bytes" << Limits.MemoryUsageBytes;
+ }
+ Cbo.EndObject();
+
+ Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
}
CbObject
@@ -310,6 +376,81 @@ HttpHubService::GetActivityCounter()
}
void
+HttpHubService::HandleDeprovisionAll(HttpServerRequest& Request)
+{
+ std::vector<std::string> ModulesToDeprovision;
+ m_Hub.EnumerateModules([&ModulesToDeprovision](std::string_view ModuleId, const Hub::InstanceInfo& InstanceInfo) {
+ if (InstanceInfo.State == HubInstanceState::Provisioned || InstanceInfo.State == HubInstanceState::Hibernated)
+ {
+ ModulesToDeprovision.push_back(std::string(ModuleId));
+ }
+ });
+
+ if (ModulesToDeprovision.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::OK);
+ }
+ std::vector<std::string> Rejected;
+ std::vector<std::string> Accepted;
+ std::vector<std::string> Completed;
+ for (const std::string& ModuleId : ModulesToDeprovision)
+ {
+ Hub::Response Response = m_Hub.Deprovision(ModuleId);
+ switch (Response.ResponseCode)
+ {
+ case Hub::EResponseCode::NotFound:
+ // Ignore
+ break;
+ case Hub::EResponseCode::Rejected:
+ Rejected.push_back(ModuleId);
+ break;
+ case Hub::EResponseCode::Accepted:
+ Accepted.push_back(ModuleId);
+ break;
+ case Hub::EResponseCode::Completed:
+ Completed.push_back(ModuleId);
+ break;
+ }
+ }
+ if (Rejected.empty() && Accepted.empty() && Completed.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::OK);
+ }
+ HttpResponseCode Response = HttpResponseCode::OK;
+ CbObjectWriter Writer;
+ if (!Completed.empty())
+ {
+ Writer.BeginArray("Completed");
+ for (const std::string& ModuleId : Completed)
+ {
+ Writer.AddString(ModuleId);
+ }
+ Writer.EndArray(); // Completed
+ }
+ if (!Accepted.empty())
+ {
+ Writer.BeginArray("Accepted");
+ for (const std::string& ModuleId : Accepted)
+ {
+ Writer.AddString(ModuleId);
+ }
+ Writer.EndArray(); // Accepted
+ Response = HttpResponseCode::Accepted;
+ }
+ if (!Rejected.empty())
+ {
+ Writer.BeginArray("Rejected");
+ for (const std::string& ModuleId : Rejected)
+ {
+ Writer.AddString(ModuleId);
+ }
+ Writer.EndArray(); // Rejected
+ Response = HttpResponseCode::Conflict;
+ }
+ Request.WriteResponse(Response, Writer.Save());
+}
+
+void
HttpHubService::HandleModuleGet(HttpServerRequest& Request, std::string_view ModuleId)
{
Hub::InstanceInfo InstanceInfo;
@@ -328,45 +469,36 @@ HttpHubService::HandleModuleGet(HttpServerRequest& Request, std::string_view Mod
void
HttpHubService::HandleModuleDelete(HttpServerRequest& Request, std::string_view ModuleId)
{
- Hub::InstanceInfo InstanceInfo;
- if (!m_Hub.Find(ModuleId, &InstanceInfo))
+ Hub::Response Resp = m_Hub.Obliterate(std::string(ModuleId));
+
+ if (HandleFailureResults(Request, Resp))
{
- Request.WriteResponse(HttpResponseCode::NotFound);
return;
}
- if (InstanceInfo.State == HubInstanceState::Provisioned || InstanceInfo.State == HubInstanceState::Hibernated ||
- InstanceInfo.State == HubInstanceState::Crashed)
- {
- try
- {
- Hub::Response Resp = m_Hub.Deprovision(std::string(ModuleId));
-
- if (HandleFailureResults(Request, Resp))
- {
- return;
- }
-
- // TODO: nuke all related storage
+ const HttpResponseCode HttpCode =
+ (Resp.ResponseCode == Hub::EResponseCode::Accepted) ? HttpResponseCode::Accepted : HttpResponseCode::OK;
+ CbObjectWriter Obj;
+ Obj << "moduleId" << ModuleId;
+ Request.WriteResponse(HttpCode, Obj.Save());
+}
- const HttpResponseCode HttpCode =
- (Resp.ResponseCode == Hub::EResponseCode::Accepted) ? HttpResponseCode::Accepted : HttpResponseCode::OK;
- CbObjectWriter Obj;
- Obj << "moduleId" << ModuleId;
- return Request.WriteResponse(HttpCode, Obj.Save());
- }
- catch (const std::exception& Ex)
- {
- ZEN_ERROR("Exception while deprovisioning module '{}': {}", ModuleId, Ex.what());
- throw;
- }
- }
+void
+HttpHubService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
+{
+ m_Proxy.OnWebSocketOpen(std::move(Connection), RelativeUri);
+}
- // TODO: nuke all related storage
+void
+HttpHubService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg)
+{
+ m_Proxy.OnWebSocketMessage(Conn, Msg);
+}
- CbObjectWriter Obj;
- Obj << "moduleId" << ModuleId;
- Request.WriteResponse(HttpResponseCode::OK, Obj.Save());
+void
+HttpHubService::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason)
+{
+ m_Proxy.OnWebSocketClose(Conn, Code, Reason);
}
} // namespace zen
diff --git a/src/zenserver/hub/httphubservice.h b/src/zenserver/hub/httphubservice.h
index 1bb1c303e..f4d1b0b89 100644
--- a/src/zenserver/hub/httphubservice.h
+++ b/src/zenserver/hub/httphubservice.h
@@ -2,11 +2,16 @@
#pragma once
+#include <zencore/thread.h>
#include <zenhttp/httpserver.h>
#include <zenhttp/httpstatus.h>
+#include <zenhttp/websocket.h>
+
+#include <memory>
namespace zen {
+class HttpProxyHandler;
class HttpStatsService;
class Hub;
@@ -16,10 +21,10 @@ class Hub;
* use in UEFN content worker style scenarios.
*
*/
-class HttpHubService : public HttpService, public IHttpStatusProvider, public IHttpStatsProvider
+class HttpHubService : public HttpService, public IHttpStatusProvider, public IHttpStatsProvider, public IWebSocketHandler
{
public:
- HttpHubService(Hub& Hub, HttpStatsService& StatsService, HttpStatusService& StatusService);
+ HttpHubService(Hub& Hub, HttpProxyHandler& Proxy, HttpStatsService& StatsService, HttpStatusService& StatusService);
~HttpHubService();
HttpHubService(const HttpHubService&) = delete;
@@ -32,6 +37,11 @@ public:
virtual CbObject CollectStats() override;
virtual uint64_t GetActivityCounter() override;
+ // IWebSocketHandler
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
+
void SetNotificationEndpoint(std::string_view UpstreamNotificationEndpoint, std::string_view InstanceId);
private:
@@ -43,8 +53,11 @@ private:
HttpStatsService& m_StatsService;
HttpStatusService& m_StatusService;
+ void HandleDeprovisionAll(HttpServerRequest& Request);
void HandleModuleGet(HttpServerRequest& Request, std::string_view ModuleId);
void HandleModuleDelete(HttpServerRequest& Request, std::string_view ModuleId);
+
+ HttpProxyHandler& m_Proxy;
};
} // namespace zen
diff --git a/src/zenserver/hub/httpproxyhandler.cpp b/src/zenserver/hub/httpproxyhandler.cpp
new file mode 100644
index 000000000..235d7388f
--- /dev/null
+++ b/src/zenserver/hub/httpproxyhandler.cpp
@@ -0,0 +1,528 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpproxyhandler.h"
+
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+#include <zenhttp/httpclient.h>
+#include <zenhttp/httpwsclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <charconv>
+
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+#endif // ZEN_WITH_TESTS
+
+namespace zen {
+
+namespace {
+
+ std::string InjectProxyScript(std::string_view Html, uint16_t Port)
+ {
+ ExtendableStringBuilder<2048> Script;
+ Script.Append("<script>\n(function(){\n var P = \"/hub/proxy/");
+ Script.Append(fmt::format("{}", Port));
+ Script.Append(
+ "\";\n"
+ " var OF = window.fetch;\n"
+ " window.fetch = function(u, o) {\n"
+ " if (typeof u === \"string\") {\n"
+ " try {\n"
+ " var p = new URL(u, location.origin);\n"
+ " if (p.origin === location.origin && !p.pathname.startsWith(P))\n"
+ " { p.pathname = P + p.pathname; u = p.toString(); }\n"
+ " } catch(e) {\n"
+ " if (u.startsWith(\"/\") && !u.startsWith(P)) u = P + u;\n"
+ " }\n"
+ " }\n"
+ " return OF.call(this, u, o);\n"
+ " };\n"
+ " var OW = window.WebSocket;\n"
+ " window.WebSocket = function(u, pr) {\n"
+ " try {\n"
+ " var p = new URL(u);\n"
+ " if (p.hostname === location.hostname\n"
+ " && String(p.port || (p.protocol === \"wss:\" ? \"443\" : \"80\"))\n"
+ " === String(location.port || (location.protocol === \"https:\" ? \"443\" : \"80\"))\n"
+ " && !p.pathname.startsWith(P))\n"
+ " { p.pathname = P + p.pathname; u = p.toString(); }\n"
+ " } catch(e) {}\n"
+ " return pr !== undefined ? new OW(u, pr) : new OW(u);\n"
+ " };\n"
+ " window.WebSocket.prototype = OW.prototype;\n"
+ " window.WebSocket.CONNECTING = OW.CONNECTING;\n"
+ " window.WebSocket.OPEN = OW.OPEN;\n"
+ " window.WebSocket.CLOSING = OW.CLOSING;\n"
+ " window.WebSocket.CLOSED = OW.CLOSED;\n"
+ " var OO = window.open;\n"
+ " window.open = function(u, t, f) {\n"
+ " if (typeof u === \"string\") {\n"
+ " try {\n"
+ " var p = new URL(u, location.origin);\n"
+ " if (p.origin === location.origin && !p.pathname.startsWith(P))\n"
+ " { p.pathname = P + p.pathname; u = p.toString(); }\n"
+ " } catch(e) {}\n"
+ " }\n"
+ " return OO.call(this, u, t, f);\n"
+ " };\n"
+ " document.addEventListener(\"click\", function(e) {\n"
+ " var t = e.composedPath ? e.composedPath()[0] : e.target;\n"
+ " while (t && t.tagName !== \"A\") t = t.parentNode || t.host;\n"
+ " if (!t || !t.href) return;\n"
+ " try {\n"
+ " var h = new URL(t.href);\n"
+ " if (h.origin === location.origin && !h.pathname.startsWith(P))\n"
+ " { h.pathname = P + h.pathname; e.preventDefault(); window.location.href = h.toString(); }\n"
+ " } catch(x) {}\n"
+ " }, true);\n"
+ "})();\n</script>");
+
+ std::string ScriptStr = Script.ToString();
+
+ size_t HeadClose = Html.find("</head>");
+ if (HeadClose != std::string_view::npos)
+ {
+ std::string Result;
+ Result.reserve(Html.size() + ScriptStr.size());
+ Result.append(Html.substr(0, HeadClose));
+ Result.append(ScriptStr);
+ Result.append(Html.substr(HeadClose));
+ return Result;
+ }
+
+ std::string Result;
+ Result.reserve(Html.size() + ScriptStr.size());
+ Result.append(ScriptStr);
+ Result.append(Html);
+ return Result;
+ }
+
+} // namespace
+
+struct HttpProxyHandler::WsBridge : public RefCounted, public IWsClientHandler
+{
+ Ref<WebSocketConnection> ClientConn;
+ std::unique_ptr<HttpWsClient> UpstreamClient;
+ uint16_t Port = 0;
+
+ void OnWsOpen() override {}
+
+ void OnWsMessage(const WebSocketMessage& Msg) override
+ {
+ if (!ClientConn->IsOpen())
+ {
+ return;
+ }
+ switch (Msg.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ ClientConn->SendText(std::string_view(static_cast<const char*>(Msg.Payload.GetData()), Msg.Payload.GetSize()));
+ break;
+ case WebSocketOpcode::kBinary:
+ ClientConn->SendBinary(std::span<const uint8_t>(static_cast<const uint8_t*>(Msg.Payload.GetData()), Msg.Payload.GetSize()));
+ break;
+ default:
+ break;
+ }
+ }
+
+ void OnWsClose(uint16_t Code, std::string_view Reason) override
+ {
+ if (ClientConn->IsOpen())
+ {
+ ClientConn->Close(Code, Reason);
+ }
+ }
+};
+
+HttpProxyHandler::HttpProxyHandler()
+{
+}
+
+HttpProxyHandler::HttpProxyHandler(PortValidator ValidatePort) : m_ValidatePort(std::move(ValidatePort))
+{
+}
+
+void
+HttpProxyHandler::SetPortValidator(PortValidator ValidatePort)
+{
+ m_ValidatePort = std::move(ValidatePort);
+}
+
+HttpProxyHandler::~HttpProxyHandler()
+{
+ try
+ {
+ Shutdown();
+ }
+ catch (...)
+ {
+ }
+}
+
+HttpClient&
+HttpProxyHandler::GetOrCreateProxyClient(uint16_t Port)
+{
+ HttpClient* Result = nullptr;
+ m_ProxyClientsLock.WithExclusiveLock([&] {
+ auto It = m_ProxyClients.find(Port);
+ if (It == m_ProxyClients.end())
+ {
+ HttpClientSettings Settings;
+ Settings.LogCategory = "hub-proxy";
+ Settings.ConnectTimeout = std::chrono::milliseconds(5000);
+ Settings.Timeout = std::chrono::milliseconds(30000);
+ auto Client = std::make_unique<HttpClient>(fmt::format("http://127.0.0.1:{}", Port), Settings);
+ Result = Client.get();
+ m_ProxyClients.emplace(Port, std::move(Client));
+ }
+ else
+ {
+ Result = It->second.get();
+ }
+ });
+ return *Result;
+}
+
+void
+HttpProxyHandler::HandleProxyRequest(HttpServerRequest& Request, std::string_view PortStr, std::string_view PathTail)
+{
+ uint16_t Port = 0;
+ auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port);
+ if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size())
+ {
+ Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "invalid proxy URL");
+ return;
+ }
+
+ if (!m_ValidatePort(Port))
+ {
+ Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "target instance not available");
+ return;
+ }
+
+ HttpClient& Client = GetOrCreateProxyClient(Port);
+
+ std::string RequestPath;
+ RequestPath.reserve(1 + PathTail.size());
+ RequestPath.push_back('/');
+ RequestPath.append(PathTail);
+
+ std::string_view QueryString = Request.QueryString();
+ if (!QueryString.empty())
+ {
+ RequestPath.push_back('?');
+ RequestPath.append(QueryString);
+ }
+
+ HttpClient::KeyValueMap ForwardHeaders;
+ HttpContentType AcceptType = Request.AcceptContentType();
+ if (AcceptType != HttpContentType::kUnknownContentType)
+ {
+ ForwardHeaders->emplace("Accept", std::string(MapContentTypeToString(AcceptType)));
+ }
+
+ std::string_view Auth = Request.GetAuthorizationHeader();
+ if (!Auth.empty())
+ {
+ ForwardHeaders->emplace("Authorization", std::string(Auth));
+ }
+
+ HttpContentType ReqContentType = Request.RequestContentType();
+ if (ReqContentType != HttpContentType::kUnknownContentType)
+ {
+ ForwardHeaders->emplace("Content-Type", std::string(MapContentTypeToString(ReqContentType)));
+ }
+
+ HttpClient::Response Response;
+
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ Response = Client.Get(RequestPath, ForwardHeaders);
+ break;
+ case HttpVerb::kPost:
+ {
+ IoBuffer Payload = Request.ReadPayload();
+ Response = Client.Post(RequestPath, Payload, ForwardHeaders);
+ break;
+ }
+ case HttpVerb::kPut:
+ {
+ IoBuffer Payload = Request.ReadPayload();
+ Response = Client.Put(RequestPath, Payload, ForwardHeaders);
+ break;
+ }
+ case HttpVerb::kDelete:
+ Response = Client.Delete(RequestPath, ForwardHeaders);
+ break;
+ case HttpVerb::kHead:
+ Response = Client.Head(RequestPath, ForwardHeaders);
+ break;
+ default:
+ Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "method not supported");
+ return;
+ }
+
+ if (Response.Error)
+ {
+ if (!m_ValidatePort(Port))
+ {
+ Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "target instance not available");
+ return;
+ }
+
+ ZEN_WARN("proxy request to port {} failed: {}", Port, Response.Error->ErrorMessage);
+ switch (Response.Error->ErrorCode)
+ {
+ case HttpClientErrorCode::kConnectionFailure:
+ case HttpClientErrorCode::kHostResolutionFailure:
+ return Request.WriteResponse(HttpResponseCode::NotFound,
+ HttpContentType::kText,
+ fmt::format("instance not reachable: {}", Response.Error->ErrorMessage));
+ case HttpClientErrorCode::kOperationTimedOut:
+ return Request.WriteResponse(HttpResponseCode::GatewayTimeout,
+ HttpContentType::kText,
+ fmt::format("upstream request timed out: {}", Response.Error->ErrorMessage));
+ case HttpClientErrorCode::kRequestCancelled:
+ return Request.WriteResponse(HttpResponseCode::ServiceUnavailable,
+ HttpContentType::kText,
+ fmt::format("upstream request cancelled: {}", Response.Error->ErrorMessage));
+ default:
+ return Request.WriteResponse(HttpResponseCode::BadGateway,
+ HttpContentType::kText,
+ fmt::format("upstream request failed: {}", Response.Error->ErrorMessage));
+ }
+ }
+
+ HttpContentType ContentType = Response.ResponsePayload.GetContentType();
+
+ if (ContentType == HttpContentType::kHTML)
+ {
+ std::string_view Html(static_cast<const char*>(Response.ResponsePayload.GetData()), Response.ResponsePayload.GetSize());
+ std::string Injected = InjectProxyScript(Html, Port);
+ Request.WriteResponse(Response.StatusCode, HttpContentType::kHTML, std::string_view(Injected));
+ }
+ else
+ {
+ Request.WriteResponse(Response.StatusCode, ContentType, std::move(Response.ResponsePayload));
+ }
+}
+
+void
+HttpProxyHandler::PrunePort(uint16_t Port)
+{
+ m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.erase(Port); });
+
+ std::vector<Ref<WsBridge>> Stale;
+ m_WsBridgesLock.WithExclusiveLock([&] {
+ for (auto It = m_WsBridges.begin(); It != m_WsBridges.end();)
+ {
+ if (It->second->Port == Port)
+ {
+ Stale.push_back(std::move(It->second));
+ It = m_WsBridges.erase(It);
+ }
+ else
+ {
+ ++It;
+ }
+ }
+ });
+
+ for (auto& Bridge : Stale)
+ {
+ if (Bridge->UpstreamClient)
+ {
+ Bridge->UpstreamClient->Close(1001, "instance shutting down");
+ }
+ if (Bridge->ClientConn->IsOpen())
+ {
+ Bridge->ClientConn->Close(1001, "instance shutting down");
+ }
+ }
+}
+
+void
+HttpProxyHandler::Shutdown()
+{
+ m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.clear(); });
+ m_ProxyClientsLock.WithExclusiveLock([&] { m_ProxyClients.clear(); });
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// WebSocket proxy
+//
+
+void
+HttpProxyHandler::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
+{
+ const std::string_view ProxyPrefix = "proxy/";
+ if (!RelativeUri.starts_with(ProxyPrefix))
+ {
+ Connection->Close(1008, "unsupported WebSocket endpoint");
+ return;
+ }
+
+ std::string_view ProxyTail = RelativeUri.substr(ProxyPrefix.size());
+
+ size_t SlashPos = ProxyTail.find('/');
+ std::string_view PortStr = (SlashPos != std::string_view::npos) ? ProxyTail.substr(0, SlashPos) : ProxyTail;
+ std::string_view Path = (SlashPos != std::string_view::npos) ? ProxyTail.substr(SlashPos) : "/";
+
+ uint16_t Port = 0;
+ auto [Ptr, Ec] = std::from_chars(PortStr.data(), PortStr.data() + PortStr.size(), Port);
+ if (Ec != std::errc{} || Ptr != PortStr.data() + PortStr.size())
+ {
+ Connection->Close(1008, "invalid proxy URL");
+ return;
+ }
+
+ if (!m_ValidatePort(Port))
+ {
+ Connection->Close(1008, "target instance not available");
+ return;
+ }
+
+ std::string WsUrl = HttpToWsUrl(fmt::format("http://127.0.0.1:{}", Port), Path);
+
+ Ref<WsBridge> Bridge(new WsBridge());
+ Bridge->ClientConn = Connection;
+ Bridge->Port = Port;
+
+ Bridge->UpstreamClient = std::make_unique<HttpWsClient>(WsUrl, *Bridge);
+
+ try
+ {
+ Bridge->UpstreamClient->Connect();
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_WARN("proxy WebSocket connect to {} failed: {}", WsUrl, Ex.what());
+ Connection->Close(1011, "upstream connect failed");
+ return;
+ }
+
+ WebSocketConnection* Key = Connection.Get();
+ m_WsBridgesLock.WithExclusiveLock([&] { m_WsBridges.emplace(Key, std::move(Bridge)); });
+}
+
+void
+HttpProxyHandler::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg)
+{
+ Ref<WsBridge> Bridge;
+ m_WsBridgesLock.WithSharedLock([&] {
+ auto It = m_WsBridges.find(&Conn);
+ if (It != m_WsBridges.end())
+ {
+ Bridge = It->second;
+ }
+ });
+
+ if (!Bridge || !Bridge->UpstreamClient)
+ {
+ return;
+ }
+
+ switch (Msg.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ Bridge->UpstreamClient->SendText(std::string_view(static_cast<const char*>(Msg.Payload.GetData()), Msg.Payload.GetSize()));
+ break;
+ case WebSocketOpcode::kBinary:
+ Bridge->UpstreamClient->SendBinary(
+ std::span<const uint8_t>(static_cast<const uint8_t*>(Msg.Payload.GetData()), Msg.Payload.GetSize()));
+ break;
+ case WebSocketOpcode::kClose:
+ Bridge->UpstreamClient->Close(Msg.CloseCode, {});
+ break;
+ default:
+ break;
+ }
+}
+
+void
+HttpProxyHandler::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason)
+{
+ Ref<WsBridge> Bridge = m_WsBridgesLock.WithExclusiveLock([this, &Conn]() -> Ref<WsBridge> {
+ auto It = m_WsBridges.find(&Conn);
+ if (It != m_WsBridges.end())
+ {
+ Ref<WsBridge> Bridge = std::move(It->second);
+ m_WsBridges.erase(It);
+ return Bridge;
+ }
+ return {};
+ });
+
+ if (Bridge && Bridge->UpstreamClient)
+ {
+ Bridge->UpstreamClient->Close(Code, Reason);
+ }
+}
+
+#if ZEN_WITH_TESTS
+
+TEST_SUITE_BEGIN("server.httpproxyhandler");
+
+TEST_CASE("server.httpproxyhandler.html_injection")
+{
+ SUBCASE("injects before </head>")
+ {
+ std::string Result = InjectProxyScript("<html><head></head><body></body></html>", 21005);
+ CHECK(Result.find("<script>") != std::string::npos);
+ CHECK(Result.find("/hub/proxy/21005") != std::string::npos);
+ size_t ScriptEnd = Result.find("</script>");
+ size_t HeadClose = Result.find("</head>");
+ REQUIRE(ScriptEnd != std::string::npos);
+ REQUIRE(HeadClose != std::string::npos);
+ CHECK(ScriptEnd < HeadClose);
+ }
+
+ SUBCASE("prepends when no </head>")
+ {
+ std::string Result = InjectProxyScript("<body>content</body>", 21005);
+ CHECK(Result.find("<script>") == 0);
+ CHECK(Result.find("<body>content</body>") != std::string::npos);
+ }
+
+ SUBCASE("empty html")
+ {
+ std::string Result = InjectProxyScript("", 21005);
+ CHECK(Result.find("<script>") != std::string::npos);
+ CHECK(Result.find("/hub/proxy/21005") != std::string::npos);
+ }
+
+ SUBCASE("preserves original content")
+ {
+ std::string_view Html = "<html><head><title>Test</title></head><body><h1>Dashboard</h1></body></html>";
+ std::string Result = InjectProxyScript(Html, 21005);
+ CHECK(Result.find("<title>Test</title>") != std::string::npos);
+ CHECK(Result.find("<h1>Dashboard</h1>") != std::string::npos);
+ }
+}
+
+TEST_CASE("server.httpproxyhandler.port_embedding")
+{
+ std::string Result = InjectProxyScript("<head></head>", 80);
+ CHECK(Result.find("/hub/proxy/80") != std::string::npos);
+
+ Result = InjectProxyScript("<head></head>", 65535);
+ CHECK(Result.find("/hub/proxy/65535") != std::string::npos);
+}
+
+TEST_SUITE_END();
+
+void
+httpproxyhandler_forcelink()
+{
+}
+#endif // ZEN_WITH_TESTS
+
+} // namespace zen
diff --git a/src/zenserver/hub/httpproxyhandler.h b/src/zenserver/hub/httpproxyhandler.h
new file mode 100644
index 000000000..8667c0ca1
--- /dev/null
+++ b/src/zenserver/hub/httpproxyhandler.h
@@ -0,0 +1,52 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/thread.h>
+#include <zenhttp/httpserver.h>
+#include <zenhttp/websocket.h>
+
+#include <functional>
+#include <memory>
+#include <unordered_map>
+
+namespace zen {
+
+class HttpClient;
+
+class HttpProxyHandler
+{
+public:
+ using PortValidator = std::function<bool(uint16_t)>;
+
+ HttpProxyHandler();
+ explicit HttpProxyHandler(PortValidator ValidatePort);
+ ~HttpProxyHandler();
+
+ void SetPortValidator(PortValidator ValidatePort);
+
+ HttpProxyHandler(const HttpProxyHandler&) = delete;
+ HttpProxyHandler& operator=(const HttpProxyHandler&) = delete;
+
+ void HandleProxyRequest(HttpServerRequest& Request, std::string_view PortStr, std::string_view PathTail);
+ void PrunePort(uint16_t Port);
+ void Shutdown();
+
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri);
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg);
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason);
+
+private:
+ PortValidator m_ValidatePort;
+
+ HttpClient& GetOrCreateProxyClient(uint16_t Port);
+
+ RwLock m_ProxyClientsLock;
+ std::unordered_map<uint16_t, std::unique_ptr<HttpClient>> m_ProxyClients;
+
+ struct WsBridge;
+ RwLock m_WsBridgesLock;
+ std::unordered_map<WebSocketConnection*, Ref<WsBridge>> m_WsBridges;
+};
+
+} // namespace zen
diff --git a/src/zenserver/hub/hub.cpp b/src/zenserver/hub/hub.cpp
index 6c44e2333..128d3ed35 100644
--- a/src/zenserver/hub/hub.cpp
+++ b/src/zenserver/hub/hub.cpp
@@ -19,7 +19,6 @@ ZEN_THIRD_PARTY_INCLUDES_START
ZEN_THIRD_PARTY_INCLUDES_END
#if ZEN_WITH_TESTS
-# include <zencore/filesystem.h>
# include <zencore/testing.h>
# include <zencore/testutils.h>
#endif
@@ -122,23 +121,73 @@ private:
//////////////////////////////////////////////////////////////////////////
-Hub::Hub(const Configuration& Config,
- ZenServerEnvironment&& RunEnvironment,
- WorkerThreadPool* OptionalWorkerPool,
- AsyncModuleStateChangeCallbackFunc&& ModuleStateChangeCallback)
+ProcessMetrics
+Hub::AtomicProcessMetrics::Load() const
+{
+ return {
+ .MemoryBytes = MemoryBytes.load(),
+ .KernelTimeMs = KernelTimeMs.load(),
+ .UserTimeMs = UserTimeMs.load(),
+ .WorkingSetSize = WorkingSetSize.load(),
+ .PeakWorkingSetSize = PeakWorkingSetSize.load(),
+ .PagefileUsage = PagefileUsage.load(),
+ .PeakPagefileUsage = PeakPagefileUsage.load(),
+ };
+}
+
+void
+Hub::AtomicProcessMetrics::Store(const ProcessMetrics& Metrics)
+{
+ MemoryBytes.store(Metrics.MemoryBytes);
+ KernelTimeMs.store(Metrics.KernelTimeMs);
+ UserTimeMs.store(Metrics.UserTimeMs);
+ WorkingSetSize.store(Metrics.WorkingSetSize);
+ PeakWorkingSetSize.store(Metrics.PeakWorkingSetSize);
+ PagefileUsage.store(Metrics.PagefileUsage);
+ PeakPagefileUsage.store(Metrics.PeakPagefileUsage);
+}
+
+void
+Hub::AtomicProcessMetrics::Reset()
+{
+ MemoryBytes.store(0);
+ KernelTimeMs.store(0);
+ UserTimeMs.store(0);
+ WorkingSetSize.store(0);
+ PeakWorkingSetSize.store(0);
+ PagefileUsage.store(0);
+ PeakPagefileUsage.store(0);
+}
+
+void
+Hub::GetMachineMetrics(SystemMetrics& OutSystemMetrict, DiskSpace& OutDiskSpace) const
+{
+ m_Lock.WithSharedLock([&]() {
+ OutSystemMetrict = m_SystemMetrics;
+ OutDiskSpace = m_DiskSpace;
+ });
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+Hub::Hub(const Configuration& Config, ZenServerEnvironment&& RunEnvironment, AsyncModuleStateChangeCallbackFunc&& ModuleStateChangeCallback)
: m_Config(Config)
, m_RunEnvironment(std::move(RunEnvironment))
-, m_WorkerPool(OptionalWorkerPool)
+, m_WorkerPool(Config.OptionalProvisionWorkerPool)
, m_BackgroundWorkLatch(1)
, m_ModuleStateChangeCallback(std::move(ModuleStateChangeCallback))
, m_ActiveInstances(Config.InstanceLimit)
, m_FreeActiveInstanceIndexes(Config.InstanceLimit)
{
- m_HostMetrics = GetSystemMetrics();
- m_ResourceLimits.DiskUsageBytes = 1000ull * 1024 * 1024 * 1024;
- m_ResourceLimits.MemoryUsageBytes = 16ull * 1024 * 1024 * 1024;
+ ZEN_ASSERT_FORMAT(
+ Config.OptionalProvisionWorkerPool != Config.OptionalHydrationWorkerPool || Config.OptionalProvisionWorkerPool == nullptr,
+ "Provision and hydration worker pools must be distinct to avoid deadlocks");
- if (m_Config.HydrationTargetSpecification.empty())
+ if (!m_Config.HydrationTargetSpecification.empty())
+ {
+ m_HydrationTargetSpecification = m_Config.HydrationTargetSpecification;
+ }
+ else if (!m_Config.HydrationOptions)
{
std::filesystem::path FileHydrationPath = m_RunEnvironment.CreateChildDir("hydration_storage");
ZEN_INFO("using file hydration path: '{}'", FileHydrationPath);
@@ -146,7 +195,7 @@ Hub::Hub(const Configuration& Config,
}
else
{
- m_HydrationTargetSpecification = m_Config.HydrationTargetSpecification;
+ m_HydrationOptions = m_Config.HydrationOptions;
}
m_HydrationTempPath = m_RunEnvironment.CreateChildDir("hydration_temp");
@@ -171,6 +220,9 @@ Hub::Hub(const Configuration& Config,
}
}
#endif
+
+ UpdateMachineMetrics();
+
m_WatchDog = std::thread([this]() { WatchDog(); });
}
@@ -195,6 +247,9 @@ Hub::Shutdown()
{
ZEN_INFO("Hub service shutting down, deprovisioning any current instances");
+ bool Expected = false;
+ bool WaitForBackgroundWork = m_ShutdownFlag.compare_exchange_strong(Expected, true);
+
m_WatchDogEvent.Set();
if (m_WatchDog.joinable())
{
@@ -203,8 +258,6 @@ Hub::Shutdown()
m_WatchDog = {};
- bool Expected = false;
- bool WaitForBackgroundWork = m_ShutdownFlag.compare_exchange_strong(Expected, true);
if (WaitForBackgroundWork && m_WorkerPool)
{
m_BackgroundWorkLatch.CountDown();
@@ -254,7 +307,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
if (auto It = m_InstanceLookup.find(std::string(ModuleId)); It == m_InstanceLookup.end())
{
std::string Reason;
- if (!CanProvisionInstance(ModuleId, /* out */ Reason))
+ if (!CanProvisionInstanceLocked(ModuleId, /* out */ Reason))
{
ZEN_WARN("Cannot provision new storage server instance for module '{}': {}", ModuleId, Reason);
@@ -272,11 +325,18 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
auto NewInstance = std::make_unique<StorageServerInstance>(
m_RunEnvironment,
StorageServerInstance::Configuration{.BasePort = GetInstanceIndexAssignedPort(ActiveInstanceIndex),
- .HydrationTempPath = m_HydrationTempPath,
+ .StateDir = m_RunEnvironment.CreateChildDir(ModuleId),
+ .TempDir = m_HydrationTempPath / ModuleId,
.HydrationTargetSpecification = m_HydrationTargetSpecification,
+ .HydrationOptions = m_HydrationOptions,
.HttpThreadCount = m_Config.InstanceHttpThreadCount,
.CoreLimit = m_Config.InstanceCoreLimit,
- .ConfigPath = m_Config.InstanceConfigPath},
+ .ConfigPath = m_Config.InstanceConfigPath,
+ .Malloc = m_Config.InstanceMalloc,
+ .Trace = m_Config.InstanceTrace,
+ .TraceHost = m_Config.InstanceTraceHost,
+ .TraceFile = m_Config.InstanceTraceFile,
+ .OptionalWorkerPool = m_Config.OptionalHydrationWorkerPool},
ModuleId);
#if ZEN_PLATFORM_WINDOWS
@@ -289,6 +349,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
Instance = NewInstance->LockExclusive(/*Wait*/ true);
m_ActiveInstances[ActiveInstanceIndex].Instance = std::move(NewInstance);
+ m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Reset();
m_InstanceLookup.insert_or_assign(std::string(ModuleId), ActiveInstanceIndex);
// Set Provisioning while both hub lock and instance lock are held so that any
// concurrent Deprovision sees the in-flight state, not Unprovisioned.
@@ -334,6 +395,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
case HubInstanceState::Unprovisioned:
break;
case HubInstanceState::Provisioned:
+ m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(std::chrono::system_clock::now());
return Response{EResponseCode::Completed};
case HubInstanceState::Hibernated:
_.ReleaseNow();
@@ -354,6 +416,7 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo)
Instance = {};
if (ActualState == HubInstanceState::Provisioned)
{
+ m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(std::chrono::system_clock::now());
return Response{EResponseCode::Completed};
}
if (ActualState == HubInstanceState::Provisioning)
@@ -540,6 +603,7 @@ Hub::InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveI
switch (CurrentState)
{
case HubInstanceState::Deprovisioning:
+ case HubInstanceState::Obliterating:
return Response{EResponseCode::Accepted};
case HubInstanceState::Crashed:
case HubInstanceState::Hibernated:
@@ -585,11 +649,11 @@ Hub::InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveI
try
{
m_WorkerPool->ScheduleWork(
- [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, Instance = std::move(SharedInstancePtr)]() mutable {
+ [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, Instance = std::move(SharedInstancePtr), OldState]() mutable {
auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); });
try
{
- CompleteDeprovision(*Instance, ActiveInstanceIndex);
+ CompleteDeprovision(*Instance, ActiveInstanceIndex, OldState);
}
catch (const std::exception& Ex)
{
@@ -617,20 +681,222 @@ Hub::InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveI
}
else
{
- CompleteDeprovision(Instance, ActiveInstanceIndex);
+ CompleteDeprovision(Instance, ActiveInstanceIndex, OldState);
+ }
+
+ return Response{m_WorkerPool ? EResponseCode::Accepted : EResponseCode::Completed};
+}
+
+Hub::Response
+Hub::Obliterate(const std::string& ModuleId)
+{
+ ZEN_ASSERT(!m_ShutdownFlag.load());
+
+ StorageServerInstance::ExclusiveLockedPtr Instance;
+ size_t ActiveInstanceIndex = (size_t)-1;
+ {
+ RwLock::ExclusiveLockScope Lock(m_Lock);
+
+ if (auto It = m_InstanceLookup.find(ModuleId); It != m_InstanceLookup.end())
+ {
+ ActiveInstanceIndex = It->second;
+ ZEN_ASSERT(ActiveInstanceIndex < m_ActiveInstances.size());
+
+ HubInstanceState CurrentState = m_ActiveInstances[ActiveInstanceIndex].State.load();
+
+ switch (CurrentState)
+ {
+ case HubInstanceState::Obliterating:
+ return Response{EResponseCode::Accepted};
+ case HubInstanceState::Provisioned:
+ case HubInstanceState::Hibernated:
+ case HubInstanceState::Crashed:
+ break;
+ case HubInstanceState::Deprovisioning:
+ return Response{EResponseCode::Rejected,
+ fmt::format("Module '{}' is being deprovisioned, retry after completion", ModuleId)};
+ case HubInstanceState::Recovering:
+ return Response{EResponseCode::Rejected, fmt::format("Module '{}' is currently recovering from a crash", ModuleId)};
+ case HubInstanceState::Unprovisioned:
+ return Response{EResponseCode::Completed};
+ default:
+ return Response{EResponseCode::Rejected,
+ fmt::format("Module '{}' is currently in state '{}'", ModuleId, ToString(CurrentState))};
+ }
+
+ std::unique_ptr<StorageServerInstance>& RawInstance = m_ActiveInstances[ActiveInstanceIndex].Instance;
+ ZEN_ASSERT(RawInstance != nullptr);
+
+ Instance = RawInstance->LockExclusive(/*Wait*/ true);
+ }
+ else
+ {
+ // Module not tracked by hub - obliterate backend data directly.
+ // Covers the deprovisioned case where data was preserved via dehydration.
+ if (m_ObliteratingInstances.contains(ModuleId))
+ {
+ return Response{EResponseCode::Accepted};
+ }
+
+ m_ObliteratingInstances.insert(ModuleId);
+ Lock.ReleaseNow();
+
+ if (m_WorkerPool)
+ {
+ m_BackgroundWorkLatch.AddCount(1);
+ try
+ {
+ m_WorkerPool->ScheduleWork(
+ [this, ModuleId = std::string(ModuleId)]() {
+ auto Guard = MakeGuard([this, ModuleId]() {
+ m_Lock.WithExclusiveLock([this, ModuleId]() { m_ObliteratingInstances.erase(ModuleId); });
+ m_BackgroundWorkLatch.CountDown();
+ });
+ try
+ {
+ ObliterateBackendData(ModuleId);
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_ERROR("Failed async obliterate of untracked module '{}': {}", ModuleId, Ex.what());
+ }
+ },
+ WorkerThreadPool::EMode::EnableBacklog);
+ }
+ catch (const std::exception& DispatchEx)
+ {
+ ZEN_ERROR("Failed to dispatch async obliterate of untracked module '{}': {}", ModuleId, DispatchEx.what());
+ m_BackgroundWorkLatch.CountDown();
+ {
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_ObliteratingInstances.erase(ModuleId);
+ }
+ throw;
+ }
+
+ return Response{EResponseCode::Accepted};
+ }
+
+ auto _ = MakeGuard([this, &ModuleId]() {
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_ObliteratingInstances.erase(ModuleId);
+ });
+
+ ObliterateBackendData(ModuleId);
+
+ return Response{EResponseCode::Completed};
+ }
+ }
+
+ HubInstanceState OldState = UpdateInstanceState(Instance, ActiveInstanceIndex, HubInstanceState::Obliterating);
+ const uint16_t Port = Instance.GetBasePort();
+ NotifyStateUpdate(ModuleId, OldState, HubInstanceState::Obliterating, Port, {});
+
+ if (m_WorkerPool)
+ {
+ std::shared_ptr<StorageServerInstance::ExclusiveLockedPtr> SharedInstancePtr =
+ std::make_shared<StorageServerInstance::ExclusiveLockedPtr>(std::move(Instance));
+
+ m_BackgroundWorkLatch.AddCount(1);
+ try
+ {
+ m_WorkerPool->ScheduleWork(
+ [this, ModuleId = std::string(ModuleId), ActiveInstanceIndex, Instance = std::move(SharedInstancePtr)]() mutable {
+ auto _ = MakeGuard([this]() { m_BackgroundWorkLatch.CountDown(); });
+ try
+ {
+ CompleteObliterate(*Instance, ActiveInstanceIndex);
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_ERROR("Failed async obliterate of module '{}': {}", ModuleId, Ex.what());
+ }
+ },
+ WorkerThreadPool::EMode::EnableBacklog);
+ }
+ catch (const std::exception& DispatchEx)
+ {
+ ZEN_ERROR("Failed async dispatch obliterate of module '{}': {}", ModuleId, DispatchEx.what());
+ m_BackgroundWorkLatch.CountDown();
+
+ NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, OldState, Port, {});
+ {
+ RwLock::ExclusiveLockScope HubLock(m_Lock);
+ ZEN_ASSERT_SLOW(m_InstanceLookup.find(std::string(ModuleId)) != m_InstanceLookup.end());
+ ZEN_ASSERT_SLOW(m_InstanceLookup.find(std::string(ModuleId))->second == ActiveInstanceIndex);
+ UpdateInstanceState(HubLock, ActiveInstanceIndex, OldState);
+ }
+
+ throw;
+ }
+ }
+ else
+ {
+ CompleteObliterate(Instance, ActiveInstanceIndex);
}
return Response{m_WorkerPool ? EResponseCode::Accepted : EResponseCode::Completed};
}
void
-Hub::CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex)
+Hub::CompleteObliterate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex)
{
const std::string ModuleId(Instance.GetModuleId());
const uint16_t Port = Instance.GetBasePort();
try
{
+ Instance.Obliterate();
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_ERROR("Failed to obliterate storage server instance for module '{}': {}", ModuleId, Ex.what());
+ Instance = {};
+ {
+ RwLock::ExclusiveLockScope HubLock(m_Lock);
+ UpdateInstanceState(HubLock, ActiveInstanceIndex, HubInstanceState::Crashed);
+ }
+ NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, HubInstanceState::Crashed, Port, {});
+ throw;
+ }
+
+ NotifyStateUpdate(ModuleId, HubInstanceState::Obliterating, HubInstanceState::Unprovisioned, Port, {});
+ RemoveInstance(Instance, ActiveInstanceIndex, ModuleId);
+}
+
+void
+Hub::CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState)
+{
+ const std::string ModuleId(Instance.GetModuleId());
+ const uint16_t Port = Instance.GetBasePort();
+
+ try
+ {
+ if (OldState == HubInstanceState::Provisioned)
+ {
+ ZEN_INFO("Triggering GC for module {}", ModuleId);
+
+ HttpClient GcClient(fmt::format("http://localhost:{}", Port));
+
+ HttpClient::KeyValueMap Params;
+ Params.Entries.insert({"smallobjects", "true"});
+ Params.Entries.insert({"skipcid", "false"});
+ HttpClient::Response Response = GcClient.Post("/admin/gc", HttpClient::Accept(HttpContentType::kCbObject), Params);
+ Stopwatch Timer;
+ while (Response && Timer.GetElapsedTimeMs() < 5000)
+ {
+ Response = GcClient.Get("/admin/gc", HttpClient::Accept(HttpContentType::kCbObject));
+ if (Response)
+ {
+ bool Complete = Response.AsObject()["Status"].AsString() != "Running";
+ if (Complete)
+ {
+ break;
+ }
+ Sleep(50);
+ }
+ }
+ }
Instance.Deprovision();
}
catch (const std::exception& Ex)
@@ -649,20 +915,7 @@ Hub::CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, si
}
NotifyStateUpdate(ModuleId, HubInstanceState::Deprovisioning, HubInstanceState::Unprovisioned, Port, {});
- Instance = {};
-
- std::unique_ptr<StorageServerInstance> DeleteInstance;
- {
- RwLock::ExclusiveLockScope HubLock(m_Lock);
- auto It = m_InstanceLookup.find(std::string(ModuleId));
- ZEN_ASSERT_SLOW(It != m_InstanceLookup.end());
- ZEN_ASSERT_SLOW(It->second == ActiveInstanceIndex);
- DeleteInstance = std::move(m_ActiveInstances[ActiveInstanceIndex].Instance);
- m_FreeActiveInstanceIndexes.push_back(ActiveInstanceIndex);
- m_InstanceLookup.erase(It);
- UpdateInstanceState(HubLock, ActiveInstanceIndex, HubInstanceState::Unprovisioned);
- }
- DeleteInstance.reset();
+ RemoveInstance(Instance, ActiveInstanceIndex, ModuleId);
}
Hub::Response
@@ -935,6 +1188,50 @@ Hub::CompleteWake(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t Ac
}
}
+void
+Hub::RemoveInstance(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, std::string_view ModuleId)
+{
+ Instance = {};
+
+ std::unique_ptr<StorageServerInstance> DeleteInstance;
+ {
+ RwLock::ExclusiveLockScope HubLock(m_Lock);
+ auto It = m_InstanceLookup.find(std::string(ModuleId));
+ ZEN_ASSERT_SLOW(It != m_InstanceLookup.end());
+ ZEN_ASSERT_SLOW(It->second == ActiveInstanceIndex);
+ DeleteInstance = std::move(m_ActiveInstances[ActiveInstanceIndex].Instance);
+ m_FreeActiveInstanceIndexes.push_back(ActiveInstanceIndex);
+ m_InstanceLookup.erase(It);
+ UpdateInstanceState(HubLock, ActiveInstanceIndex, HubInstanceState::Unprovisioned);
+ }
+ DeleteInstance.reset();
+}
+
+void
+Hub::ObliterateBackendData(std::string_view ModuleId)
+{
+ std::filesystem::path ServerStateDir = m_RunEnvironment.GetChildBaseDir() / ModuleId;
+ std::filesystem::path TempDir = m_HydrationTempPath / ModuleId;
+
+ std::atomic<bool> AbortFlag{false};
+ std::atomic<bool> PauseFlag{false};
+
+ HydrationConfig Config{.ServerStateDir = ServerStateDir,
+ .TempDir = TempDir,
+ .ModuleId = std::string(ModuleId),
+ .TargetSpecification = m_HydrationTargetSpecification,
+ .Options = m_HydrationOptions};
+ if (m_Config.OptionalHydrationWorkerPool)
+ {
+ Config.Threading.emplace(HydrationConfig::ThreadingOptions{.WorkerPool = m_Config.OptionalHydrationWorkerPool,
+ .AbortFlag = &AbortFlag,
+ .PauseFlag = &PauseFlag});
+ }
+
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Obliterate();
+}
+
bool
Hub::Find(std::string_view ModuleId, InstanceInfo* OutInstanceInfo)
{
@@ -947,12 +1244,10 @@ Hub::Find(std::string_view ModuleId, InstanceInfo* OutInstanceInfo)
ZEN_ASSERT(ActiveInstanceIndex < m_ActiveInstances.size());
const std::unique_ptr<StorageServerInstance>& Instance = m_ActiveInstances[ActiveInstanceIndex].Instance;
ZEN_ASSERT(Instance);
- InstanceInfo Info{
- m_ActiveInstances[ActiveInstanceIndex].State.load(),
- std::chrono::system_clock::now() // TODO
- };
- Instance->GetProcessMetrics(Info.Metrics);
- Info.Port = Instance->GetBasePort();
+ InstanceInfo Info{m_ActiveInstances[ActiveInstanceIndex].State.load(),
+ m_ActiveInstances[ActiveInstanceIndex].StateChangeTime.load()};
+ Info.Metrics = m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Load();
+ Info.Port = Instance->GetBasePort();
*OutInstanceInfo = Info;
}
@@ -971,12 +1266,10 @@ Hub::EnumerateModules(std::function<void(std::string_view ModuleId, const Instan
{
const std::unique_ptr<StorageServerInstance>& Instance = m_ActiveInstances[ActiveInstanceIndex].Instance;
ZEN_ASSERT(Instance);
- InstanceInfo Info{
- m_ActiveInstances[ActiveInstanceIndex].State.load(),
- std::chrono::system_clock::now() // TODO
- };
- Instance->GetProcessMetrics(Info.Metrics);
- Info.Port = Instance->GetBasePort();
+ InstanceInfo Info{m_ActiveInstances[ActiveInstanceIndex].State.load(),
+ m_ActiveInstances[ActiveInstanceIndex].StateChangeTime.load()};
+ Info.Metrics = m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Load();
+ Info.Port = Instance->GetBasePort();
Infos.push_back(std::make_pair(std::string(Instance->GetModuleId()), Info));
}
@@ -994,30 +1287,15 @@ Hub::GetInstanceCount()
return m_Lock.WithSharedLock([this]() { return gsl::narrow_cast<int>(m_InstanceLookup.size()); });
}
-void
-Hub::UpdateCapacityMetrics()
-{
- m_HostMetrics = GetSystemMetrics();
-
- // TODO: Should probably go into WatchDog and use atomic for update so it can be read without locks...
- // Per-instance stats are already refreshed by WatchDog and are readable via the Find and EnumerateModules
-}
-
-void
-Hub::UpdateStats()
-{
- int CurrentInstanceCount = m_Lock.WithSharedLock([this] { return gsl::narrow_cast<int>(m_InstanceLookup.size()); });
- int CurrentMaxCount = m_MaxInstanceCount.load();
-
- int NewMax = Max(CurrentMaxCount, CurrentInstanceCount);
-
- m_MaxInstanceCount.compare_exchange_weak(CurrentMaxCount, NewMax);
-}
-
bool
-Hub::CanProvisionInstance(std::string_view ModuleId, std::string& OutReason)
+Hub::CanProvisionInstanceLocked(std::string_view ModuleId, std::string& OutReason)
{
- ZEN_UNUSED(ModuleId);
+ if (m_ObliteratingInstances.contains(std::string(ModuleId)))
+ {
+ OutReason = fmt::format("module '{}' is being obliterated", ModuleId);
+ return false;
+ }
+
if (m_FreeActiveInstanceIndexes.empty())
{
OutReason = fmt::format("instance limit ({}) exceeded", m_Config.InstanceLimit);
@@ -1025,7 +1303,24 @@ Hub::CanProvisionInstance(std::string_view ModuleId, std::string& OutReason)
return false;
}
- // TODO: handle additional resource metrics
+ const uint64_t DiskUsedBytes = m_DiskSpace.Free <= m_DiskSpace.Total ? m_DiskSpace.Total - m_DiskSpace.Free : 0;
+ if (m_Config.ResourceLimits.DiskUsageBytes > 0 && DiskUsedBytes > m_Config.ResourceLimits.DiskUsageBytes)
+ {
+ OutReason =
+ fmt::format("disk usage ({}) exceeds ({})", NiceBytes(DiskUsedBytes), NiceBytes(m_Config.ResourceLimits.DiskUsageBytes));
+ return false;
+ }
+
+ const uint64_t RamUsedMiB = m_SystemMetrics.AvailSystemMemoryMiB <= m_SystemMetrics.SystemMemoryMiB
+ ? m_SystemMetrics.SystemMemoryMiB - m_SystemMetrics.AvailSystemMemoryMiB
+ : 0;
+ const uint64_t RamUsedBytes = RamUsedMiB * 1024 * 1024;
+ if (m_Config.ResourceLimits.MemoryUsageBytes > 0 && RamUsedBytes > m_Config.ResourceLimits.MemoryUsageBytes)
+ {
+ OutReason =
+ fmt::format("ram usage ({}) exceeds ({})", NiceBytes(RamUsedBytes), NiceBytes(m_Config.ResourceLimits.MemoryUsageBytes));
+ return false;
+ }
return true;
}
@@ -1036,6 +1331,21 @@ Hub::GetInstanceIndexAssignedPort(size_t ActiveInstanceIndex) const
return gsl::narrow<uint16_t>(m_Config.BasePortNumber + ActiveInstanceIndex);
}
+bool
+Hub::IsInstancePort(uint16_t Port) const
+{
+ if (Port < m_Config.BasePortNumber)
+ {
+ return false;
+ }
+ size_t Index = Port - m_Config.BasePortNumber;
+ if (Index >= m_ActiveInstances.size())
+ {
+ return false;
+ }
+ return m_ActiveInstances[Index].State.load(std::memory_order_relaxed) != HubInstanceState::Unprovisioned;
+}
+
HubInstanceState
Hub::UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewState)
{
@@ -1046,11 +1356,13 @@ Hub::UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewS
case HubInstanceState::Unprovisioned:
return To == HubInstanceState::Provisioning;
case HubInstanceState::Provisioned:
- return To == HubInstanceState::Hibernating || To == HubInstanceState::Deprovisioning || To == HubInstanceState::Crashed;
+ return To == HubInstanceState::Hibernating || To == HubInstanceState::Deprovisioning || To == HubInstanceState::Crashed ||
+ To == HubInstanceState::Obliterating;
case HubInstanceState::Hibernated:
- return To == HubInstanceState::Waking || To == HubInstanceState::Deprovisioning;
+ return To == HubInstanceState::Waking || To == HubInstanceState::Deprovisioning || To == HubInstanceState::Obliterating;
case HubInstanceState::Crashed:
- return To == HubInstanceState::Provisioning || To == HubInstanceState::Deprovisioning || To == HubInstanceState::Recovering;
+ return To == HubInstanceState::Provisioning || To == HubInstanceState::Deprovisioning ||
+ To == HubInstanceState::Recovering || To == HubInstanceState::Obliterating;
case HubInstanceState::Provisioning:
return To == HubInstanceState::Provisioned || To == HubInstanceState::Unprovisioned || To == HubInstanceState::Crashed;
case HubInstanceState::Hibernating:
@@ -1062,11 +1374,15 @@ Hub::UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewS
To == HubInstanceState::Crashed;
case HubInstanceState::Recovering:
return To == HubInstanceState::Provisioned || To == HubInstanceState::Unprovisioned;
+ case HubInstanceState::Obliterating:
+ return To == HubInstanceState::Unprovisioned || To == HubInstanceState::Crashed;
}
return false;
}(m_ActiveInstances[ActiveInstanceIndex].State.load(), NewState));
+ const std::chrono::system_clock::time_point Now = std::chrono::system_clock::now();
m_ActiveInstances[ActiveInstanceIndex].LastKnownActivitySum.store(0);
- m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(std::chrono::system_clock::now());
+ m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.store(Now);
+ m_ActiveInstances[ActiveInstanceIndex].StateChangeTime.store(Now);
return m_ActiveInstances[ActiveInstanceIndex].State.exchange(NewState);
}
@@ -1075,10 +1391,14 @@ Hub::AttemptRecoverInstance(std::string_view ModuleId)
{
StorageServerInstance::ExclusiveLockedPtr Instance;
size_t ActiveInstanceIndex = (size_t)-1;
-
{
RwLock::ExclusiveLockScope _(m_Lock);
+ if (m_ShutdownFlag.load())
+ {
+ return;
+ }
+
auto It = m_InstanceLookup.find(std::string(ModuleId));
if (It == m_InstanceLookup.end())
{
@@ -1173,14 +1493,14 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient,
StorageServerInstance::SharedLockedPtr&& LockedInstance,
size_t ActiveInstanceIndex)
{
+ const std::string ModuleId(LockedInstance.GetModuleId());
+
HubInstanceState InstanceState = m_ActiveInstances[ActiveInstanceIndex].State.load();
if (LockedInstance.IsRunning())
{
- LockedInstance.UpdateMetrics();
+ m_ActiveInstances[ActiveInstanceIndex].ProcessMetrics.Store(LockedInstance.GetProcessMetrics());
if (InstanceState == HubInstanceState::Provisioned)
{
- const std::string ModuleId(LockedInstance.GetModuleId());
-
const uint16_t Port = LockedInstance.GetBasePort();
const uint64_t PreviousActivitySum = m_ActiveInstances[ActiveInstanceIndex].LastKnownActivitySum.load();
const std::chrono::system_clock::time_point LastActivityTime = m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.load();
@@ -1260,8 +1580,7 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient,
else if (InstanceState == HubInstanceState::Provisioned)
{
// Process is not running but state says it should be - instance died unexpectedly.
- const std::string ModuleId(LockedInstance.GetModuleId());
- const uint16_t Port = LockedInstance.GetBasePort();
+ const uint16_t Port = LockedInstance.GetBasePort();
UpdateInstanceState(LockedInstance, ActiveInstanceIndex, HubInstanceState::Crashed);
NotifyStateUpdate(ModuleId, HubInstanceState::Provisioned, HubInstanceState::Crashed, Port, {});
LockedInstance = {};
@@ -1272,7 +1591,6 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient,
{
// Process is not running - no HTTP activity check is possible.
// Use a pure time-based check; the margin window does not apply here.
- const std::string ModuleId = std::string(LockedInstance.GetModuleId());
const std::chrono::system_clock::time_point LastActivityTime = m_ActiveInstances[ActiveInstanceIndex].LastActivityTime.load();
const uint64_t PreviousActivitySum = m_ActiveInstances[ActiveInstanceIndex].LastKnownActivitySum.load();
const std::chrono::system_clock::time_point Now = std::chrono::system_clock::now();
@@ -1304,7 +1622,7 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient,
}
else
{
- // transitional state (Provisioning, Deprovisioning, Hibernating, Waking, Recovering) - expected, skip.
+ // transitional state (Provisioning, Deprovisioning, Hibernating, Waking, Recovering, Obliterating) - expected, skip.
// Crashed is handled above via AttemptRecoverInstance; it appears here only when the instance
// lock was busy on a previous cycle and recovery is already pending.
return true;
@@ -1312,6 +1630,43 @@ Hub::CheckInstanceStatus(HttpClient& ActivityCheckClient,
}
void
+Hub::UpdateMachineMetrics()
+{
+ try
+ {
+ bool DiskSpaceOk = false;
+ DiskSpace Disk;
+
+ std::filesystem::path ChildDir = m_RunEnvironment.GetChildBaseDir();
+ if (!ChildDir.empty())
+ {
+ if (DiskSpaceInfo(ChildDir, Disk))
+ {
+ DiskSpaceOk = true;
+ }
+ else
+ {
+ ZEN_WARN("Failed to query disk space for '{}'; disk-based provisioning limits will not be enforced", ChildDir);
+ }
+ }
+
+ SystemMetrics Metrics = GetSystemMetrics();
+
+ m_Lock.WithExclusiveLock([&]() {
+ if (DiskSpaceOk)
+ {
+ m_DiskSpace = Disk;
+ }
+ m_SystemMetrics = Metrics;
+ });
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_WARN("Failed to update machine metrics. Reason: {}", Ex.what());
+ }
+}
+
+void
Hub::WatchDog()
{
const uint64_t CycleIntervalMs = std::chrono::duration_cast<std::chrono::milliseconds>(m_Config.WatchDog.CycleInterval).count();
@@ -1326,16 +1681,18 @@ Hub::WatchDog()
[&]() -> bool { return m_WatchDogEvent.Wait(0); });
size_t CheckInstanceIndex = SIZE_MAX; // first increment wraps to 0
- while (!m_WatchDogEvent.Wait(gsl::narrow<int>(CycleIntervalMs)))
+ while (!m_ShutdownFlag.load() && !m_WatchDogEvent.Wait(gsl::narrow<int>(CycleIntervalMs)))
{
try
{
+ UpdateMachineMetrics();
+
// Snapshot slot count. We iterate all slots (including freed nulls) so
// round-robin coverage is not skewed by deprovisioned entries.
size_t SlotsRemaining = m_Lock.WithSharedLock([this]() { return m_ActiveInstances.size(); });
Stopwatch Timer;
- bool ShuttingDown = false;
+ bool ShuttingDown = m_ShutdownFlag.load();
while (SlotsRemaining > 0 && Timer.GetElapsedTimeMs() < CycleProcessingBudgetMs && !ShuttingDown)
{
StorageServerInstance::SharedLockedPtr LockedInstance;
@@ -1366,16 +1723,24 @@ Hub::WatchDog()
std::string ModuleId(LockedInstance.GetModuleId());
- bool InstanceIsOk = CheckInstanceStatus(ActivityCheckClient, std::move(LockedInstance), CheckInstanceIndex);
- if (InstanceIsOk)
+ try
{
- ShuttingDown = m_WatchDogEvent.Wait(gsl::narrow<int>(InstanceCheckThrottleMs));
+ bool InstanceIsOk = CheckInstanceStatus(ActivityCheckClient, std::move(LockedInstance), CheckInstanceIndex);
+ if (InstanceIsOk)
+ {
+ ShuttingDown = m_WatchDogEvent.Wait(gsl::narrow<int>(InstanceCheckThrottleMs));
+ }
+ else
+ {
+ ZEN_WARN("Instance for module '{}' is not running, attempting recovery", ModuleId);
+ AttemptRecoverInstance(ModuleId);
+ }
}
- else
+ catch (const std::exception& Ex)
{
- ZEN_WARN("Instance for module '{}' is not running, attempting recovery", ModuleId);
- AttemptRecoverInstance(ModuleId);
+ ZEN_WARN("Failed to check status of module {}. Reason: {}", ModuleId, Ex.what());
}
+ ShuttingDown |= m_ShutdownFlag.load();
}
}
catch (const std::exception& Ex)
@@ -1417,6 +1782,14 @@ static const HttpClientSettings kFastTimeout{.ConnectTimeout = std::chrono::mill
namespace hub_testutils {
+ struct TestHubPools
+ {
+ WorkerThreadPool ProvisionPool;
+ WorkerThreadPool HydrationPool;
+
+ explicit TestHubPools(int ThreadCount) : ProvisionPool(ThreadCount, "hub_test_prov"), HydrationPool(ThreadCount, "hub_test_hydr") {}
+ };
+
ZenServerEnvironment MakeHubEnvironment(const std::filesystem::path& BaseDir)
{
return ZenServerEnvironment(ZenServerEnvironment::Hub, GetRunningExecutablePath().parent_path(), BaseDir);
@@ -1425,9 +1798,14 @@ namespace hub_testutils {
std::unique_ptr<Hub> MakeHub(const std::filesystem::path& BaseDir,
Hub::Configuration Config = {},
Hub::AsyncModuleStateChangeCallbackFunc StateChangeCallback = {},
- WorkerThreadPool* WorkerPool = nullptr)
+ TestHubPools* Pools = nullptr)
{
- return std::make_unique<Hub>(Config, MakeHubEnvironment(BaseDir), WorkerPool, std::move(StateChangeCallback));
+ if (Pools)
+ {
+ Config.OptionalProvisionWorkerPool = &Pools->ProvisionPool;
+ Config.OptionalHydrationWorkerPool = &Pools->HydrationPool;
+ }
+ return std::make_unique<Hub>(Config, MakeHubEnvironment(BaseDir), std::move(StateChangeCallback));
}
struct CallbackRecord
@@ -1499,14 +1877,32 @@ namespace hub_testutils {
} // namespace hub_testutils
-TEST_CASE("hub.provision_basic")
+TEST_CASE("hub.provision")
{
ScopedTemporaryDirectory TempDir;
- std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path());
+
+ struct TransitionRecord
+ {
+ HubInstanceState OldState;
+ HubInstanceState NewState;
+ };
+ RwLock CaptureMutex;
+ std::vector<TransitionRecord> Transitions;
+
+ hub_testutils::StateChangeCapture CaptureInstance;
+
+ auto CaptureFunc =
+ [&](std::string_view ModuleId, const HubProvisionedInstanceInfo& Info, HubInstanceState OldState, HubInstanceState NewState) {
+ CaptureMutex.WithExclusiveLock([&]() { Transitions.push_back({OldState, NewState}); });
+ CaptureInstance.CaptureFunc()(ModuleId, Info, OldState, NewState);
+ };
+
+ std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), {}, std::move(CaptureFunc));
CHECK_EQ(HubInstance->GetInstanceCount(), 0);
CHECK_FALSE(HubInstance->Find("module_a"));
+ // Provision
HubProvisionedInstanceInfo Info;
const Hub::Response ProvisionResult = HubInstance->Provision("module_a", Info);
REQUIRE_MESSAGE(ProvisionResult.ResponseCode == Hub::EResponseCode::Completed, ProvisionResult.Message);
@@ -1515,12 +1911,23 @@ TEST_CASE("hub.provision_basic")
Hub::InstanceInfo InstanceInfo;
REQUIRE(HubInstance->Find("module_a", &InstanceInfo));
CHECK_EQ(InstanceInfo.State, HubInstanceState::Provisioned);
+ CHECK_NE(InstanceInfo.StateChangeTime, std::chrono::system_clock::time_point::min());
+ CHECK_LE(InstanceInfo.StateChangeTime, std::chrono::system_clock::now());
{
HttpClient ModClient(fmt::format("http://localhost:{}", Info.Port), kFastTimeout);
CHECK(ModClient.Get("/health/"));
}
+ // Verify provision callback
+ {
+ RwLock::SharedLockScope _(CaptureInstance.CallbackMutex);
+ REQUIRE_EQ(CaptureInstance.ProvisionCallbacks.size(), 1u);
+ CHECK_EQ(CaptureInstance.ProvisionCallbacks[0].ModuleId, "module_a");
+ CHECK_EQ(CaptureInstance.ProvisionCallbacks[0].Port, Info.Port);
+ }
+
+ // Deprovision
const Hub::Response DeprovisionResult = HubInstance->Deprovision("module_a");
CHECK(DeprovisionResult.ResponseCode == Hub::EResponseCode::Completed);
CHECK_EQ(HubInstance->GetInstanceCount(), 0);
@@ -1530,6 +1937,28 @@ TEST_CASE("hub.provision_basic")
HttpClient ModClient(fmt::format("http://localhost:{}", Info.Port), kFastTimeout);
CHECK(!ModClient.Get("/health/"));
}
+
+ // Verify deprovision callback
+ {
+ RwLock::SharedLockScope _(CaptureInstance.CallbackMutex);
+ REQUIRE_EQ(CaptureInstance.DeprovisionCallbacks.size(), 1u);
+ CHECK_EQ(CaptureInstance.DeprovisionCallbacks[0].ModuleId, "module_a");
+ CHECK_EQ(CaptureInstance.DeprovisionCallbacks[0].Port, Info.Port);
+ }
+
+ // Verify full transition sequence
+ {
+ RwLock::SharedLockScope _(CaptureMutex);
+ REQUIRE_EQ(Transitions.size(), 4u);
+ CHECK_EQ(Transitions[0].OldState, HubInstanceState::Unprovisioned);
+ CHECK_EQ(Transitions[0].NewState, HubInstanceState::Provisioning);
+ CHECK_EQ(Transitions[1].OldState, HubInstanceState::Provisioning);
+ CHECK_EQ(Transitions[1].NewState, HubInstanceState::Provisioned);
+ CHECK_EQ(Transitions[2].OldState, HubInstanceState::Provisioned);
+ CHECK_EQ(Transitions[2].NewState, HubInstanceState::Deprovisioning);
+ CHECK_EQ(Transitions[3].OldState, HubInstanceState::Deprovisioning);
+ CHECK_EQ(Transitions[3].NewState, HubInstanceState::Unprovisioned);
+ }
}
TEST_CASE("hub.provision_config")
@@ -1582,92 +2011,6 @@ TEST_CASE("hub.provision_config")
}
}
-TEST_CASE("hub.provision_callbacks")
-{
- ScopedTemporaryDirectory TempDir;
-
- hub_testutils::StateChangeCapture CaptureInstance;
-
- std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), {}, CaptureInstance.CaptureFunc());
-
- HubProvisionedInstanceInfo Info;
-
- const Hub::Response ProvisionResult = HubInstance->Provision("cb_module", Info);
- REQUIRE_MESSAGE(ProvisionResult.ResponseCode == Hub::EResponseCode::Completed, ProvisionResult.Message);
-
- {
- RwLock::SharedLockScope _(CaptureInstance.CallbackMutex);
- REQUIRE_EQ(CaptureInstance.ProvisionCallbacks.size(), 1u);
- CHECK_EQ(CaptureInstance.ProvisionCallbacks[0].ModuleId, "cb_module");
- CHECK_EQ(CaptureInstance.ProvisionCallbacks[0].Port, Info.Port);
- CHECK_NE(CaptureInstance.ProvisionCallbacks[0].Port, 0);
- }
-
- {
- HttpClient ModClient(fmt::format("http://localhost:{}", Info.Port), kFastTimeout);
- CHECK(ModClient.Get("/health/"));
- }
-
- const Hub::Response DeprovisionResult = HubInstance->Deprovision("cb_module");
- CHECK(DeprovisionResult.ResponseCode == Hub::EResponseCode::Completed);
-
- {
- HttpClient ModClient(fmt::format("http://localhost:{}", Info.Port), kFastTimeout);
- CHECK(!ModClient.Get("/health/"));
- }
-
- {
- RwLock::SharedLockScope _(CaptureInstance.CallbackMutex);
- REQUIRE_EQ(CaptureInstance.DeprovisionCallbacks.size(), 1u);
- CHECK_EQ(CaptureInstance.DeprovisionCallbacks[0].ModuleId, "cb_module");
- CHECK_EQ(CaptureInstance.DeprovisionCallbacks[0].Port, Info.Port);
- CHECK_EQ(CaptureInstance.DeprovisionCallbacks.size(), 1u);
- }
-}
-
-TEST_CASE("hub.provision_callback_sequence")
-{
- ScopedTemporaryDirectory TempDir;
-
- struct TransitionRecord
- {
- HubInstanceState OldState;
- HubInstanceState NewState;
- };
- RwLock CaptureMutex;
- std::vector<TransitionRecord> Transitions;
-
- auto CaptureFunc =
- [&](std::string_view ModuleId, const HubProvisionedInstanceInfo& Info, HubInstanceState OldState, HubInstanceState NewState) {
- ZEN_UNUSED(ModuleId);
- ZEN_UNUSED(Info);
- CaptureMutex.WithExclusiveLock([&]() { Transitions.push_back({OldState, NewState}); });
- };
-
- std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), {}, std::move(CaptureFunc));
-
- HubProvisionedInstanceInfo Info;
- {
- const Hub::Response R = HubInstance->Provision("seq_module", Info);
- REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
- }
- {
- const Hub::Response R = HubInstance->Deprovision("seq_module");
- REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
- }
-
- RwLock::SharedLockScope _(CaptureMutex);
- REQUIRE_EQ(Transitions.size(), 4u);
- CHECK_EQ(Transitions[0].OldState, HubInstanceState::Unprovisioned);
- CHECK_EQ(Transitions[0].NewState, HubInstanceState::Provisioning);
- CHECK_EQ(Transitions[1].OldState, HubInstanceState::Provisioning);
- CHECK_EQ(Transitions[1].NewState, HubInstanceState::Provisioned);
- CHECK_EQ(Transitions[2].OldState, HubInstanceState::Provisioned);
- CHECK_EQ(Transitions[2].NewState, HubInstanceState::Deprovisioning);
- CHECK_EQ(Transitions[3].OldState, HubInstanceState::Deprovisioning);
- CHECK_EQ(Transitions[3].NewState, HubInstanceState::Unprovisioned);
-}
-
TEST_CASE("hub.instance_limit")
{
ScopedTemporaryDirectory TempDir;
@@ -1699,54 +2042,7 @@ TEST_CASE("hub.instance_limit")
CHECK_EQ(HubInstance->GetInstanceCount(), 2);
}
-TEST_CASE("hub.enumerate_modules")
-{
- ScopedTemporaryDirectory TempDir;
- std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path());
-
- HubProvisionedInstanceInfo Info;
-
- {
- const Hub::Response R = HubInstance->Provision("enum_a", Info);
- REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
- }
- {
- const Hub::Response R = HubInstance->Provision("enum_b", Info);
- REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
- }
-
- std::vector<std::string> Ids;
- int ProvisionedCount = 0;
- HubInstance->EnumerateModules([&](std::string_view ModuleId, const Hub::InstanceInfo& InstanceInfo) {
- Ids.push_back(std::string(ModuleId));
- if (InstanceInfo.State == HubInstanceState::Provisioned)
- {
- ProvisionedCount++;
- }
- });
- CHECK_EQ(Ids.size(), 2u);
- CHECK_EQ(ProvisionedCount, 2);
- const bool FoundA = std::find(Ids.begin(), Ids.end(), "enum_a") != Ids.end();
- const bool FoundB = std::find(Ids.begin(), Ids.end(), "enum_b") != Ids.end();
- CHECK(FoundA);
- CHECK(FoundB);
-
- HubInstance->Deprovision("enum_a");
- Ids.clear();
- ProvisionedCount = 0;
- HubInstance->EnumerateModules([&](std::string_view ModuleId, const Hub::InstanceInfo& InstanceInfo) {
- Ids.push_back(std::string(ModuleId));
- if (InstanceInfo.State == HubInstanceState::Provisioned)
- {
- ProvisionedCount++;
- }
- });
- REQUIRE_EQ(Ids.size(), 1u);
- CHECK_EQ(Ids[0], "enum_b");
- CHECK_EQ(ProvisionedCount, 1);
-}
-
-TEST_CASE("hub.max_instance_count")
+TEST_CASE("hub.enumerate_and_instance_tracking")
{
ScopedTemporaryDirectory TempDir;
std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path());
@@ -1756,22 +2052,56 @@ TEST_CASE("hub.max_instance_count")
HubProvisionedInstanceInfo Info;
{
- const Hub::Response R = HubInstance->Provision("max_a", Info);
+ const Hub::Response R = HubInstance->Provision("track_a", Info);
REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
}
CHECK_GE(HubInstance->GetMaxInstanceCount(), 1);
{
- const Hub::Response R = HubInstance->Provision("max_b", Info);
+ const Hub::Response R = HubInstance->Provision("track_b", Info);
REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
}
CHECK_GE(HubInstance->GetMaxInstanceCount(), 2);
+ // Enumerate both modules
+ {
+ std::vector<std::string> Ids;
+ int ProvisionedCount = 0;
+ HubInstance->EnumerateModules([&](std::string_view ModuleId, const Hub::InstanceInfo& InstanceInfo) {
+ Ids.push_back(std::string(ModuleId));
+ if (InstanceInfo.State == HubInstanceState::Provisioned)
+ {
+ ProvisionedCount++;
+ }
+ });
+ CHECK_EQ(Ids.size(), 2u);
+ CHECK_EQ(ProvisionedCount, 2);
+ CHECK(std::find(Ids.begin(), Ids.end(), "track_a") != Ids.end());
+ CHECK(std::find(Ids.begin(), Ids.end(), "track_b") != Ids.end());
+ }
+
const int MaxAfterTwo = HubInstance->GetMaxInstanceCount();
- HubInstance->Deprovision("max_a");
+ // Deprovision one - max instance count must not decrease
+ HubInstance->Deprovision("track_a");
CHECK_EQ(HubInstance->GetInstanceCount(), 1);
CHECK_EQ(HubInstance->GetMaxInstanceCount(), MaxAfterTwo);
+
+ // Enumerate after deprovision
+ {
+ std::vector<std::string> Ids;
+ int ProvisionedCount = 0;
+ HubInstance->EnumerateModules([&](std::string_view ModuleId, const Hub::InstanceInfo& InstanceInfo) {
+ Ids.push_back(std::string(ModuleId));
+ if (InstanceInfo.State == HubInstanceState::Provisioned)
+ {
+ ProvisionedCount++;
+ }
+ });
+ REQUIRE_EQ(Ids.size(), 1u);
+ CHECK_EQ(Ids[0], "track_b");
+ CHECK_EQ(ProvisionedCount, 1);
+ }
}
TEST_CASE("hub.concurrent_callbacks")
@@ -1917,7 +2247,7 @@ TEST_CASE("hub.job_object")
}
# endif // ZEN_PLATFORM_WINDOWS
-TEST_CASE("hub.hibernate_wake")
+TEST_CASE("hub.hibernate_wake_obliterate")
{
ScopedTemporaryDirectory TempDir;
Hub::Configuration Config;
@@ -1927,6 +2257,11 @@ TEST_CASE("hub.hibernate_wake")
HubProvisionedInstanceInfo ProvInfo;
Hub::InstanceInfo Info;
+ // Error cases on non-existent modules (no provision needed)
+ CHECK(HubInstance->Hibernate("never_provisioned").ResponseCode == Hub::EResponseCode::NotFound);
+ CHECK(HubInstance->Wake("never_provisioned").ResponseCode == Hub::EResponseCode::NotFound);
+ CHECK(HubInstance->Deprovision("never_provisioned").ResponseCode == Hub::EResponseCode::NotFound);
+
// Provision
{
const Hub::Response R = HubInstance->Provision("hib_a", ProvInfo);
@@ -1934,82 +2269,104 @@ TEST_CASE("hub.hibernate_wake")
}
REQUIRE(HubInstance->Find("hib_a", &Info));
CHECK_EQ(Info.State, HubInstanceState::Provisioned);
+ const std::chrono::system_clock::time_point ProvisionedTime = Info.StateChangeTime;
+ CHECK_NE(ProvisionedTime, std::chrono::system_clock::time_point::min());
+ CHECK_LE(ProvisionedTime, std::chrono::system_clock::now());
{
HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout);
CHECK(ModClient.Get("/health/"));
}
+ // Double-wake on provisioned module is idempotent
+ CHECK(HubInstance->Wake("hib_a").ResponseCode == Hub::EResponseCode::Completed);
+
// Hibernate
- const Hub::Response HibernateResult = HubInstance->Hibernate("hib_a");
- REQUIRE_MESSAGE(HibernateResult.ResponseCode == Hub::EResponseCode::Completed, HibernateResult.Message);
+ {
+ const Hub::Response R = HubInstance->Hibernate("hib_a");
+ REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
+ }
REQUIRE(HubInstance->Find("hib_a", &Info));
CHECK_EQ(Info.State, HubInstanceState::Hibernated);
+ const std::chrono::system_clock::time_point HibernatedTime = Info.StateChangeTime;
+ CHECK_GE(HibernatedTime, ProvisionedTime);
{
HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout);
CHECK(!ModClient.Get("/health/"));
}
+ // Double-hibernate on already-hibernated module is idempotent
+ CHECK(HubInstance->Hibernate("hib_a").ResponseCode == Hub::EResponseCode::Completed);
+
// Wake
- const Hub::Response WakeResult = HubInstance->Wake("hib_a");
- REQUIRE_MESSAGE(WakeResult.ResponseCode == Hub::EResponseCode::Completed, WakeResult.Message);
+ {
+ const Hub::Response R = HubInstance->Wake("hib_a");
+ REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
+ }
REQUIRE(HubInstance->Find("hib_a", &Info));
CHECK_EQ(Info.State, HubInstanceState::Provisioned);
+ CHECK_GE(Info.StateChangeTime, HibernatedTime);
{
HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout);
CHECK(ModClient.Get("/health/"));
}
- // Deprovision
- const Hub::Response DeprovisionResult = HubInstance->Deprovision("hib_a");
- CHECK(DeprovisionResult.ResponseCode == Hub::EResponseCode::Completed);
- CHECK_FALSE(HubInstance->Find("hib_a"));
+ // Hibernate again for obliterate-from-hibernated test
{
- HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout);
- CHECK(!ModClient.Get("/health/"));
+ const Hub::Response R = HubInstance->Hibernate("hib_a");
+ REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
}
-}
-
-TEST_CASE("hub.hibernate_wake_errors")
-{
- ScopedTemporaryDirectory TempDir;
- Hub::Configuration Config;
- Config.BasePortNumber = 22700;
- std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config);
-
- HubProvisionedInstanceInfo ProvInfo;
+ REQUIRE(HubInstance->Find("hib_a", &Info));
+ CHECK_EQ(Info.State, HubInstanceState::Hibernated);
- // Hibernate/wake on a non-existent module - returns NotFound (-> 404)
- CHECK(HubInstance->Hibernate("never_provisioned").ResponseCode == Hub::EResponseCode::NotFound);
- CHECK(HubInstance->Wake("never_provisioned").ResponseCode == Hub::EResponseCode::NotFound);
+ // Obliterate from hibernated
+ {
+ const Hub::Response R = HubInstance->Obliterate("hib_a");
+ CHECK(R.ResponseCode == Hub::EResponseCode::Completed);
+ }
+ CHECK_EQ(HubInstance->GetInstanceCount(), 0);
+ CHECK_FALSE(HubInstance->Find("hib_a"));
- // Double-hibernate: second hibernate on already-hibernated module returns Completed (idempotent)
+ // Re-provision for obliterate-from-provisioned test
{
- const Hub::Response R = HubInstance->Provision("err_b", ProvInfo);
+ const Hub::Response R = HubInstance->Provision("hib_a", ProvInfo);
REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
}
+ REQUIRE(HubInstance->Find("hib_a", &Info));
+ CHECK_EQ(Info.State, HubInstanceState::Provisioned);
{
- const Hub::Response R = HubInstance->Hibernate("err_b");
- REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
+ HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout);
+ CHECK(ModClient.Get("/health/"));
}
+ // Obliterate from provisioned
+ {
+ const Hub::Response R = HubInstance->Obliterate("hib_a");
+ CHECK(R.ResponseCode == Hub::EResponseCode::Completed);
+ }
+ CHECK_EQ(HubInstance->GetInstanceCount(), 0);
+ CHECK_FALSE(HubInstance->Find("hib_a"));
{
- const Hub::Response HibResp = HubInstance->Hibernate("err_b");
- CHECK(HibResp.ResponseCode == Hub::EResponseCode::Completed);
+ HttpClient ModClient(fmt::format("http://localhost:{}", ProvInfo.Port), kFastTimeout);
+ CHECK(!ModClient.Get("/health/"));
}
- // Wake on provisioned: succeeds (-> Provisioned), then wake again returns Completed (idempotent)
+ // Obliterate deprovisioned module (not tracked by hub, backend data may exist)
{
- const Hub::Response R = HubInstance->Wake("err_b");
+ const Hub::Response R = HubInstance->Provision("hib_a", ProvInfo);
REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
}
-
{
- const Hub::Response WakeResp = HubInstance->Wake("err_b");
- CHECK(WakeResp.ResponseCode == Hub::EResponseCode::Completed);
+ const Hub::Response R = HubInstance->Deprovision("hib_a");
+ REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
+ }
+ CHECK_FALSE(HubInstance->Find("hib_a"));
+ {
+ const Hub::Response R = HubInstance->Obliterate("hib_a");
+ CHECK(R.ResponseCode == Hub::EResponseCode::Completed);
}
- // Deprovision not-found - returns NotFound (-> 404)
- CHECK(HubInstance->Deprovision("never_provisioned").ResponseCode == Hub::EResponseCode::NotFound);
+ // Obliterate of a never-provisioned module also succeeds (no-op backend cleanup)
+ CHECK(HubInstance->Obliterate("never_existed").ResponseCode == Hub::EResponseCode::Completed);
}
TEST_CASE("hub.async_hibernate_wake")
@@ -2019,8 +2376,8 @@ TEST_CASE("hub.async_hibernate_wake")
Hub::Configuration Config;
Config.BasePortNumber = 23000;
- WorkerThreadPool WorkerPool(2, "hub_async_hib_wake");
- std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &WorkerPool);
+ hub_testutils::TestHubPools Pools(2);
+ std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &Pools);
HubProvisionedInstanceInfo ProvInfo;
Hub::InstanceInfo Info;
@@ -2150,25 +2507,21 @@ TEST_CASE("hub.recover_process_crash")
if (HubInstance->Find("module_a", &InstanceInfo) && InstanceInfo.State == HubInstanceState::Provisioned &&
ModClient.Get("/health/"))
{
- // Recovery must reuse the same port - the instance was never removed from the hub's
- // port table during recovery, so AttemptRecoverInstance reuses m_Config.BasePort.
CHECK_EQ(InstanceInfo.Port, Info.Port);
Recovered = true;
break;
}
}
- CHECK_MESSAGE(Recovered, "Instance did not recover within timeout");
+ REQUIRE_MESSAGE(Recovered, "Instance did not recover within timeout");
// Verify the full crash/recovery callback sequence
{
RwLock::SharedLockScope _(CaptureMutex);
REQUIRE_GE(Transitions.size(), 3u);
- // Find the Provisioned->Crashed transition
const auto CrashedIt = std::find_if(Transitions.begin(), Transitions.end(), [](const TransitionRecord& R) {
return R.OldState == HubInstanceState::Provisioned && R.NewState == HubInstanceState::Crashed;
});
REQUIRE_NE(CrashedIt, Transitions.end());
- // Recovery sequence follows: Crashed->Recovering, Recovering->Provisioned
const auto RecoveringIt = CrashedIt + 1;
REQUIRE_NE(RecoveringIt, Transitions.end());
CHECK_EQ(RecoveringIt->OldState, HubInstanceState::Crashed);
@@ -2178,44 +2531,6 @@ TEST_CASE("hub.recover_process_crash")
CHECK_EQ(RecoveredIt->OldState, HubInstanceState::Recovering);
CHECK_EQ(RecoveredIt->NewState, HubInstanceState::Provisioned);
}
-}
-
-TEST_CASE("hub.recover_process_crash_then_deprovision")
-{
- ScopedTemporaryDirectory TempDir;
-
- // Fast watchdog cycle so crash detection is near-instant instead of waiting up to the 3s default.
- Hub::Configuration Config;
- Config.WatchDog.CycleInterval = std::chrono::milliseconds(10);
- Config.WatchDog.InstanceCheckThrottle = std::chrono::milliseconds(1);
-
- std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config);
-
- HubProvisionedInstanceInfo Info;
- {
- const Hub::Response R = HubInstance->Provision("module_a", Info);
- REQUIRE_MESSAGE(R.ResponseCode == Hub::EResponseCode::Completed, R.Message);
- }
-
- // Kill the child process, wait for the watchdog to detect and recover the instance.
- HubInstance->TerminateModuleForTesting("module_a");
-
- constexpr auto kPollIntervalMs = std::chrono::milliseconds(50);
- constexpr auto kTimeoutMs = std::chrono::seconds(15);
- const auto Deadline = std::chrono::steady_clock::now() + kTimeoutMs;
-
- bool Recovered = false;
- while (std::chrono::steady_clock::now() < Deadline)
- {
- std::this_thread::sleep_for(kPollIntervalMs);
- Hub::InstanceInfo InstanceInfo;
- if (HubInstance->Find("module_a", &InstanceInfo) && InstanceInfo.State == HubInstanceState::Provisioned)
- {
- Recovered = true;
- break;
- }
- }
- REQUIRE_MESSAGE(Recovered, "Instance did not recover within timeout");
// After recovery, deprovision should succeed and a re-provision should work.
{
@@ -2244,8 +2559,8 @@ TEST_CASE("hub.async_provision_concurrent")
Config.BasePortNumber = 22800;
Config.InstanceLimit = kModuleCount;
- WorkerThreadPool WorkerPool(4, "hub_async_concurrent");
- std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &WorkerPool);
+ hub_testutils::TestHubPools Pools(4);
+ std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &Pools);
std::vector<HubProvisionedInstanceInfo> Infos(kModuleCount);
std::vector<std::string> Reasons(kModuleCount);
@@ -2326,8 +2641,8 @@ TEST_CASE("hub.async_provision_shutdown_waits")
Config.InstanceLimit = kModuleCount;
Config.BasePortNumber = 22900;
- WorkerThreadPool WorkerPool(2, "hub_async_shutdown");
- std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &WorkerPool);
+ hub_testutils::TestHubPools Pools(2);
+ std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &Pools);
std::vector<HubProvisionedInstanceInfo> Infos(kModuleCount);
@@ -2352,15 +2667,15 @@ TEST_CASE("hub.async_provision_shutdown_waits")
TEST_CASE("hub.async_provision_rejected")
{
- // Rejection from CanProvisionInstance fires synchronously even when a WorkerPool is present.
+ // Rejection from CanProvisionInstanceLocked fires synchronously even when a WorkerPool is present.
ScopedTemporaryDirectory TempDir;
Hub::Configuration Config;
Config.InstanceLimit = 1;
Config.BasePortNumber = 23100;
- WorkerThreadPool WorkerPool(2, "hub_async_rejected");
- std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &WorkerPool);
+ hub_testutils::TestHubPools Pools(2);
+ std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config, {}, &Pools);
HubProvisionedInstanceInfo Info;
@@ -2369,7 +2684,7 @@ TEST_CASE("hub.async_provision_rejected")
REQUIRE_MESSAGE(FirstResult.ResponseCode == Hub::EResponseCode::Accepted, FirstResult.Message);
REQUIRE_NE(Info.Port, 0);
- // Second provision: CanProvisionInstance rejects synchronously (limit reached), returns Rejected
+ // Second provision: CanProvisionInstanceLocked rejects synchronously (limit reached), returns Rejected
HubProvisionedInstanceInfo Info2;
const Hub::Response SecondResult = HubInstance->Provision("async_r2", Info2);
CHECK(SecondResult.ResponseCode == Hub::EResponseCode::Rejected);
@@ -2448,12 +2763,12 @@ TEST_CASE("hub.instance.inactivity.deprovision")
// Phase 1: immediately after setup all three instances must still be alive.
// No timeout has elapsed yet (only 100ms have passed).
- CHECK_MESSAGE(HubInstance->Find("idle"), "idle was deprovisioned within 100ms - its 2s provisioned timeout has not elapsed");
+ CHECK_MESSAGE(HubInstance->Find("idle"), "idle was deprovisioned within 100ms - its 4s provisioned timeout has not elapsed");
CHECK_MESSAGE(HubInstance->Find("idle_hib"), "idle_hib was deprovisioned within 100ms - its 1s hibernated timeout has not elapsed");
CHECK_MESSAGE(HubInstance->Find("persistent"),
- "persistent was deprovisioned within 100ms - its 2s provisioned timeout has not elapsed");
+ "persistent was deprovisioned within 100ms - its 4s provisioned timeout has not elapsed");
// Phase 2: idle_hib must be deprovisioned by the watchdog within its 1s hibernated timeout.
// idle must remain alive - its 2s provisioned timeout has not elapsed yet.
@@ -2477,7 +2792,7 @@ TEST_CASE("hub.instance.inactivity.deprovision")
CHECK_MESSAGE(!HubInstance->Find("idle_hib"), "idle_hib should still be gone - it was deprovisioned in phase 2");
- CHECK_MESSAGE(!HubInstance->Find("idle"), "idle should be gone after its 3s provisioned timeout elapsed");
+ CHECK_MESSAGE(!HubInstance->Find("idle"), "idle should be gone after its 4s provisioned timeout elapsed");
CHECK_MESSAGE(HubInstance->Find("persistent"),
"persistent was incorrectly deprovisioned - its activity timer was reset by PokeInstance");
@@ -2485,6 +2800,55 @@ TEST_CASE("hub.instance.inactivity.deprovision")
HubInstance->Shutdown();
}
+TEST_CASE("hub.machine_metrics")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), {});
+
+ // UpdateMachineMetrics() is called synchronously in the Hub constructor, so metrics
+ // are available immediately without waiting for a watchdog cycle.
+ SystemMetrics SysMetrics;
+ DiskSpace Disk;
+ HubInstance->GetMachineMetrics(SysMetrics, Disk);
+
+ CHECK_GT(Disk.Total, 0u);
+ CHECK_LE(Disk.Free, Disk.Total);
+
+ CHECK_GT(SysMetrics.SystemMemoryMiB, 0u);
+ CHECK_LE(SysMetrics.AvailSystemMemoryMiB, SysMetrics.SystemMemoryMiB);
+
+ CHECK_GT(SysMetrics.VirtualMemoryMiB, 0u);
+ CHECK_LE(SysMetrics.AvailVirtualMemoryMiB, SysMetrics.VirtualMemoryMiB);
+}
+
+TEST_CASE("hub.provision_rejected_resource_limits")
+{
+ // The Hub constructor calls UpdateMachineMetrics() synchronously, so CanProvisionInstanceLocked
+ // can enforce limits immediately without waiting for a watchdog cycle.
+ ScopedTemporaryDirectory TempDir;
+
+ {
+ Hub::Configuration Config;
+ Config.ResourceLimits.DiskUsageBytes = 1;
+ std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config);
+ HubProvisionedInstanceInfo Info;
+ const Hub::Response Result = HubInstance->Provision("disk_limit", Info);
+ CHECK(Result.ResponseCode == Hub::EResponseCode::Rejected);
+ CHECK_NE(Result.Message.find("disk usage"), std::string::npos);
+ }
+
+ {
+ Hub::Configuration Config;
+ Config.ResourceLimits.MemoryUsageBytes = 1;
+ std::unique_ptr<Hub> HubInstance = hub_testutils::MakeHub(TempDir.Path(), Config);
+ HubProvisionedInstanceInfo Info;
+ const Hub::Response Result = HubInstance->Provision("mem_limit", Info);
+ CHECK(Result.ResponseCode == Hub::EResponseCode::Rejected);
+ CHECK_NE(Result.Message.find("ram usage"), std::string::npos);
+ }
+}
+
TEST_SUITE_END();
void
diff --git a/src/zenserver/hub/hub.h b/src/zenserver/hub/hub.h
index c343b19e2..040f34af5 100644
--- a/src/zenserver/hub/hub.h
+++ b/src/zenserver/hub/hub.h
@@ -6,6 +6,8 @@
#include "resourcemetrics.h"
#include "storageserverinstance.h"
+#include <zencore/compactbinary.h>
+#include <zencore/filesystem.h>
#include <zencore/system.h>
#include <zenutil/zenserverprocess.h>
@@ -16,6 +18,7 @@
#include <memory>
#include <thread>
#include <unordered_map>
+#include <unordered_set>
namespace zen {
@@ -64,10 +67,20 @@ public:
uint32_t InstanceHttpThreadCount = 0; // Automatic
int InstanceCoreLimit = 0; // Automatic
+ std::string InstanceMalloc;
+ std::string InstanceTrace;
+ std::string InstanceTraceHost;
+ std::string InstanceTraceFile;
std::filesystem::path InstanceConfigPath;
std::string HydrationTargetSpecification;
+ CbObject HydrationOptions;
WatchDogConfiguration WatchDog;
+
+ ResourceMetrics ResourceLimits;
+
+ WorkerThreadPool* OptionalProvisionWorkerPool = nullptr;
+ WorkerThreadPool* OptionalHydrationWorkerPool = nullptr;
};
typedef std::function<
@@ -76,7 +89,6 @@ public:
Hub(const Configuration& Config,
ZenServerEnvironment&& RunEnvironment,
- WorkerThreadPool* OptionalWorkerPool = nullptr,
AsyncModuleStateChangeCallbackFunc&& ModuleStateChangeCallback = {});
~Hub();
@@ -86,7 +98,7 @@ public:
struct InstanceInfo
{
HubInstanceState State = HubInstanceState::Unprovisioned;
- std::chrono::system_clock::time_point ProvisionTime;
+ std::chrono::system_clock::time_point StateChangeTime;
ProcessMetrics Metrics;
uint16_t Port = 0;
};
@@ -126,6 +138,14 @@ public:
Response Deprovision(const std::string& ModuleId);
/**
+ * Obliterate a storage server instance and all associated data.
+ * Shuts down the process, deletes backend hydration data, and cleans local state.
+ *
+ * @param ModuleId The ID of the module to obliterate.
+ */
+ Response Obliterate(const std::string& ModuleId);
+
+ /**
* Hibernate a storage server instance for the given module ID.
* The instance is shut down but its data is preserved; it can be woken later.
*
@@ -160,6 +180,10 @@ public:
int GetMaxInstanceCount() const { return m_MaxInstanceCount.load(); }
+ void GetMachineMetrics(SystemMetrics& OutSystemMetrict, DiskSpace& OutDiskSpace) const;
+
+ bool IsInstancePort(uint16_t Port) const;
+
const Configuration& GetConfig() const { return m_Config; }
#if ZEN_WITH_TESTS
@@ -176,14 +200,31 @@ private:
AsyncModuleStateChangeCallbackFunc m_ModuleStateChangeCallback;
std::string m_HydrationTargetSpecification;
+ CbObject m_HydrationOptions;
std::filesystem::path m_HydrationTempPath;
#if ZEN_PLATFORM_WINDOWS
JobObject m_JobObject;
#endif
- RwLock m_Lock;
+ mutable RwLock m_Lock;
std::unordered_map<std::string, size_t> m_InstanceLookup;
+ // Mirrors ProcessMetrics with atomic fields, enabling lock-free reads alongside watchdog writes.
+ struct AtomicProcessMetrics
+ {
+ std::atomic<uint64_t> MemoryBytes = 0;
+ std::atomic<uint64_t> KernelTimeMs = 0;
+ std::atomic<uint64_t> UserTimeMs = 0;
+ std::atomic<uint64_t> WorkingSetSize = 0;
+ std::atomic<uint64_t> PeakWorkingSetSize = 0;
+ std::atomic<uint64_t> PagefileUsage = 0;
+ std::atomic<uint64_t> PeakPagefileUsage = 0;
+
+ ProcessMetrics Load() const;
+ void Store(const ProcessMetrics& Metrics);
+ void Reset();
+ };
+
struct ActiveInstance
{
// Invariant: Instance == nullptr if and only if State == Unprovisioned.
@@ -192,11 +233,16 @@ private:
// without holding the hub lock.
std::unique_ptr<StorageServerInstance> Instance;
std::atomic<HubInstanceState> State = HubInstanceState::Unprovisioned;
- // TODO: We should move current metrics here (from StorageServerInstance)
- // Read and updated by WatchDog, updates to State triggers a reset of both
+ // Process metrics - written by WatchDog (inside instance shared lock), read lock-free.
+ AtomicProcessMetrics ProcessMetrics;
+
+ // Activity tracking - written by WatchDog, reset on every state transition.
std::atomic<uint64_t> LastKnownActivitySum = 0;
std::atomic<std::chrono::system_clock::time_point> LastActivityTime = std::chrono::system_clock::time_point::min();
+
+ // Set in UpdateInstanceStateLocked on every state transition; read lock-free by Find/EnumerateModules.
+ std::atomic<std::chrono::system_clock::time_point> StateChangeTime = std::chrono::system_clock::time_point::min();
};
// UpdateInstanceState is overloaded to accept a locked instance pointer (exclusive or shared) or the hub exclusive
@@ -224,23 +270,23 @@ private:
}
HubInstanceState UpdateInstanceStateLocked(size_t ActiveInstanceIndex, HubInstanceState NewState);
- std::vector<ActiveInstance> m_ActiveInstances;
- std::deque<size_t> m_FreeActiveInstanceIndexes;
- ResourceMetrics m_ResourceLimits;
- SystemMetrics m_HostMetrics;
- std::atomic<int> m_MaxInstanceCount = 0;
- std::thread m_WatchDog;
+ std::vector<ActiveInstance> m_ActiveInstances;
+ std::deque<size_t> m_FreeActiveInstanceIndexes;
+ SystemMetrics m_SystemMetrics;
+ DiskSpace m_DiskSpace;
+ std::atomic<int> m_MaxInstanceCount = 0;
+ std::thread m_WatchDog;
+ std::unordered_set<std::string> m_ObliteratingInstances;
Event m_WatchDogEvent;
void WatchDog();
+ void UpdateMachineMetrics();
bool CheckInstanceStatus(HttpClient& ActivityHttpClient,
StorageServerInstance::SharedLockedPtr&& LockedInstance,
size_t ActiveInstanceIndex);
void AttemptRecoverInstance(std::string_view ModuleId);
- void UpdateStats();
- void UpdateCapacityMetrics();
- bool CanProvisionInstance(std::string_view ModuleId, std::string& OutReason);
+ bool CanProvisionInstanceLocked(std::string_view ModuleId, std::string& OutReason);
uint16_t GetInstanceIndexAssignedPort(size_t ActiveInstanceIndex) const;
Response InternalDeprovision(const std::string& ModuleId, std::function<bool(ActiveInstance& Instance)>&& DeprovisionGate);
@@ -248,9 +294,12 @@ private:
size_t ActiveInstanceIndex,
HubInstanceState OldState,
bool IsNewInstance);
- void CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex);
- void CompleteHibernate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState);
- void CompleteWake(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState);
+ void CompleteDeprovision(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState);
+ void CompleteObliterate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex);
+ void CompleteHibernate(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState);
+ void CompleteWake(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, HubInstanceState OldState);
+ void RemoveInstance(StorageServerInstance::ExclusiveLockedPtr& Instance, size_t ActiveInstanceIndex, std::string_view ModuleId);
+ void ObliterateBackendData(std::string_view ModuleId);
// Notifications may fire slightly out of sync with the Hub's internal State flag.
// The guarantee is that notifications are sent in the correct order, but the State
diff --git a/src/zenserver/hub/hubinstancestate.cpp b/src/zenserver/hub/hubinstancestate.cpp
index c47fdd294..310305e5d 100644
--- a/src/zenserver/hub/hubinstancestate.cpp
+++ b/src/zenserver/hub/hubinstancestate.cpp
@@ -29,6 +29,8 @@ ToString(HubInstanceState State)
return "crashed";
case HubInstanceState::Recovering:
return "recovering";
+ case HubInstanceState::Obliterating:
+ return "obliterating";
}
ZEN_ASSERT(false);
return "unknown";
diff --git a/src/zenserver/hub/hubinstancestate.h b/src/zenserver/hub/hubinstancestate.h
index c895f75d1..c7188aa5c 100644
--- a/src/zenserver/hub/hubinstancestate.h
+++ b/src/zenserver/hub/hubinstancestate.h
@@ -20,7 +20,8 @@ enum class HubInstanceState : uint32_t
Hibernating, // Provisioned -> Hibernated (Shutting down process, preserving data on disk)
Waking, // Hibernated -> Provisioned (Starting process from preserved data)
Deprovisioning, // Provisioned/Hibernated/Crashed -> Unprovisioned (Shutting down process and cleaning up data)
- Recovering, // Crashed -> Provisioned/Deprovisioned (Attempting in-place restart after a crash)
+ Recovering, // Crashed -> Provisioned/Unprovisioned (Attempting in-place restart after a crash)
+ Obliterating, // Provisioned/Hibernated/Crashed -> Unprovisioned (Destroying all local and backend data)
};
std::string_view ToString(HubInstanceState State);
diff --git a/src/zenserver/hub/hydration.cpp b/src/zenserver/hub/hydration.cpp
index 541127590..b356064f9 100644
--- a/src/zenserver/hub/hydration.cpp
+++ b/src/zenserver/hub/hydration.cpp
@@ -5,20 +5,25 @@
#include <zencore/basicfile.h>
#include <zencore/compactbinary.h>
#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryutil.h>
+#include <zencore/compress.h>
#include <zencore/except_fmt.h>
#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
+#include <zencore/parallelwork.h>
+#include <zencore/stream.h>
#include <zencore/system.h>
+#include <zencore/timer.h>
#include <zenutil/cloud/imdscredentials.h>
#include <zenutil/cloud/s3client.h>
+#include <zenutil/filesystemutils.h>
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <json11.hpp>
-ZEN_THIRD_PARTY_INCLUDES_END
+#include <numeric>
+#include <unordered_map>
+#include <unordered_set>
#if ZEN_WITH_TESTS
-# include <zencore/parallelwork.h>
# include <zencore/testing.h>
# include <zencore/testutils.h>
# include <zencore/thread.h>
@@ -29,7 +34,7 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
-namespace {
+namespace hydration_impl {
/// UTC time decomposed to calendar fields with sub-second milliseconds.
struct UtcTime
@@ -55,483 +60,1046 @@ namespace {
}
};
-} // namespace
+ std::filesystem::path FastRelativePath(const std::filesystem::path& Root, const std::filesystem::path& Abs)
+ {
+ auto [_, ItAbs] = std::mismatch(Root.begin(), Root.end(), Abs.begin(), Abs.end());
+ std::filesystem::path RelativePath;
+ for (auto I = ItAbs; I != Abs.end(); I++)
+ {
+ RelativePath = RelativePath / *I;
+ }
+ return RelativePath;
+ }
-///////////////////////////////////////////////////////////////////////////
+ void CleanDirectory(WorkerThreadPool& WorkerPool,
+ std::atomic<bool>& AbortFlag,
+ std::atomic<bool>& PauseFlag,
+ const std::filesystem::path& Path)
+ {
+ CleanDirectory(WorkerPool, AbortFlag, PauseFlag, Path, std::vector<std::string>{}, {}, 0);
+ }
-constexpr std::string_view FileHydratorPrefix = "file://";
+ class StorageBase
+ {
+ public:
+ virtual ~StorageBase() {}
+
+ virtual void Configure(std::string_view ModuleId,
+ const std::filesystem::path& TempDir,
+ std::string_view TargetSpecification,
+ const CbObject& Options) = 0;
+ virtual void SaveMetadata(const CbObject& Data) = 0;
+ virtual CbObject LoadMetadata() = 0;
+ virtual CbObject GetSettings() = 0;
+ virtual void ParseSettings(const CbObjectView& Settings) = 0;
+ virtual std::vector<IoHash> List() = 0;
+ virtual void Put(ParallelWork& Work,
+ WorkerThreadPool& WorkerPool,
+ const IoHash& Hash,
+ uint64_t Size,
+ const std::filesystem::path& SourcePath) = 0;
+ virtual void Get(ParallelWork& Work,
+ WorkerThreadPool& WorkerPool,
+ const IoHash& Hash,
+ uint64_t Size,
+ const std::filesystem::path& DestinationPath) = 0;
+ virtual void Delete(ParallelWork& Work, WorkerThreadPool& WorkerPool) = 0;
+ };
-struct FileHydrator : public HydrationStrategyBase
-{
- virtual void Configure(const HydrationConfig& Config) override;
- virtual void Hydrate() override;
- virtual void Dehydrate() override;
+ constexpr std::string_view FileHydratorPrefix = "file://";
+ constexpr std::string_view FileHydratorType = "file";
-private:
- HydrationConfig m_Config;
- std::filesystem::path m_StorageModuleRootDir;
-};
+ constexpr std::string_view S3HydratorPrefix = "s3://";
+ constexpr std::string_view S3HydratorType = "s3";
-void
-FileHydrator::Configure(const HydrationConfig& Config)
-{
- m_Config = Config;
+ class FileStorage : public StorageBase
+ {
+ public:
+ FileStorage() {}
+ virtual void Configure(std::string_view ModuleId,
+ const std::filesystem::path& TempDir,
+ std::string_view TargetSpecification,
+ const CbObject& Options)
+ {
+ ZEN_UNUSED(TempDir);
+ if (!TargetSpecification.empty())
+ {
+ m_StoragePath = Utf8ToWide(TargetSpecification.substr(FileHydratorPrefix.length()));
+ if (m_StoragePath.empty())
+ {
+ throw zen::runtime_error("Hydration config 'file' type requires a directory path");
+ }
+ }
+ else
+ {
+ CbObjectView Settings = Options["settings"].AsObjectView();
+ std::string_view Path = Settings["path"].AsString();
+ if (Path.empty())
+ {
+ throw zen::runtime_error("Hydration config 'file' type requires 'settings.path'");
+ }
+ m_StoragePath = Utf8ToWide(std::string(Path));
+ }
+ m_StoragePath = m_StoragePath / ModuleId;
+ MakeSafeAbsolutePathInPlace(m_StoragePath);
- std::filesystem::path ConfigPath(Utf8ToWide(m_Config.TargetSpecification.substr(FileHydratorPrefix.length())));
- MakeSafeAbsolutePathInPlace(ConfigPath);
+ m_StatePathName = m_StoragePath / "current-state.cbo";
+ m_CASPath = m_StoragePath / "cas";
+ CreateDirectories(m_CASPath);
+ }
+ virtual void SaveMetadata(const CbObject& Data)
+ {
+ BinaryWriter Output;
+ SaveCompactBinary(Output, Data);
+ WriteFile(m_StatePathName, IoBuffer(IoBuffer::Wrap, Output.GetData(), Output.GetSize()));
+ }
+ virtual CbObject LoadMetadata()
+ {
+ if (!IsFile(m_StatePathName))
+ {
+ return {};
+ }
+ FileContents Content = ReadFile(m_StatePathName);
+ if (Content.ErrorCode)
+ {
+ ThrowSystemError(Content.ErrorCode.value(), "Failed to read state file");
+ }
+ IoBuffer Payload = Content.Flatten();
+ CbValidateError Error;
+ CbObject Result = ValidateAndReadCompactBinaryObject(std::move(Payload), Error);
+ if (Error != CbValidateError::None)
+ {
+ throw std::runtime_error(fmt::format("Failed to read {} state file. Reason: {}", m_StatePathName, ToString(Error)));
+ }
+ return Result;
+ }
- if (!std::filesystem::exists(ConfigPath))
- {
- throw std::invalid_argument(fmt::format("Target does not exist: '{}'", ConfigPath.string()));
- }
+ virtual CbObject GetSettings() override { return {}; }
+ virtual void ParseSettings(const CbObjectView& Settings) { ZEN_UNUSED(Settings); }
- m_StorageModuleRootDir = ConfigPath / m_Config.ModuleId;
+ virtual std::vector<IoHash> List()
+ {
+ DirectoryContent DirContent;
+ GetDirectoryContent(m_CASPath, DirectoryContentFlags::IncludeFiles, DirContent);
+ std::vector<IoHash> Result;
+ Result.reserve(DirContent.Files.size());
+ for (const std::filesystem::path& Path : DirContent.Files)
+ {
+ IoHash Hash;
+ if (IoHash::TryParse(Path.filename().string(), Hash))
+ {
+ Result.push_back(Hash);
+ }
+ }
+ return Result;
+ }
- CreateDirectories(m_StorageModuleRootDir);
-}
+ virtual void Put(ParallelWork& Work,
+ WorkerThreadPool& WorkerPool,
+ const IoHash& Hash,
+ uint64_t Size,
+ const std::filesystem::path& SourcePath)
+ {
+ ZEN_UNUSED(Size);
+ Work.ScheduleWork(WorkerPool,
+ [this, Hash = IoHash(Hash), SourcePath = std::filesystem::path(SourcePath)](std::atomic<bool>& AbortFlag) {
+ if (!AbortFlag.load())
+ {
+ CopyFile(SourcePath, m_CASPath / fmt::format("{}", Hash), CopyFileOptions{.EnableClone = true});
+ }
+ });
+ }
-void
-FileHydrator::Hydrate()
-{
- ZEN_INFO("Hydrating state from '{}' to '{}'", m_StorageModuleRootDir, m_Config.ServerStateDir);
+ virtual void Get(ParallelWork& Work,
+ WorkerThreadPool& WorkerPool,
+ const IoHash& Hash,
+ uint64_t Size,
+ const std::filesystem::path& DestinationPath)
+ {
+ ZEN_UNUSED(Size);
+ Work.ScheduleWork(
+ WorkerPool,
+ [this, Hash = IoHash(Hash), DestinationPath = std::filesystem::path(DestinationPath)](std::atomic<bool>& AbortFlag) {
+ if (!AbortFlag.load())
+ {
+ CopyFile(m_CASPath / fmt::format("{}", Hash), DestinationPath, CopyFileOptions{.EnableClone = true});
+ }
+ });
+ }
- // Ensure target is clean
- ZEN_DEBUG("Wiping server state at '{}'", m_Config.ServerStateDir);
- const bool ForceRemoveReadOnlyFiles = true;
- CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles);
+ virtual void Delete(ParallelWork& Work, WorkerThreadPool& WorkerPool) override
+ {
+ ZEN_UNUSED(Work);
+ ZEN_UNUSED(WorkerPool);
+ DeleteDirectories(m_StoragePath);
+ }
- bool WipeServerState = false;
+ private:
+ std::filesystem::path m_StoragePath;
+ std::filesystem::path m_StatePathName;
+ std::filesystem::path m_CASPath;
+ };
- try
+ class S3Storage : public StorageBase
{
- ZEN_DEBUG("Copying '{}' to '{}'", m_StorageModuleRootDir, m_Config.ServerStateDir);
- CopyTree(m_StorageModuleRootDir, m_Config.ServerStateDir, {.EnableClone = true});
- }
- catch (std::exception& Ex)
- {
- ZEN_WARN("Copy failed: {}. Will wipe any partially copied state from '{}'", Ex.what(), m_Config.ServerStateDir);
+ public:
+ S3Storage() {}
- // We don't do the clean right here to avoid potentially running into double-throws
- WipeServerState = true;
- }
+ virtual void Configure(std::string_view ModuleId,
+ const std::filesystem::path& TempDir,
+ std::string_view TargetSpecification,
+ const CbObject& Options)
+ {
+ m_Options = Options;
- if (WipeServerState)
- {
- ZEN_DEBUG("Cleaning server state '{}'", m_Config.ServerStateDir);
- CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles);
- }
-}
+ CbObjectView Settings = m_Options["settings"].AsObjectView();
+ std::string_view Spec;
+ if (!TargetSpecification.empty())
+ {
+ Spec = TargetSpecification;
+ Spec.remove_prefix(S3HydratorPrefix.size());
+ }
+ else
+ {
+ std::string_view Uri = Settings["uri"].AsString();
+ if (Uri.empty())
+ {
+ throw zen::runtime_error("Incremental S3 hydration config requires 'settings.uri'");
+ }
+ Spec = Uri;
+ Spec.remove_prefix(S3HydratorPrefix.size());
+ }
-void
-FileHydrator::Dehydrate()
-{
- ZEN_INFO("Dehydrating state from '{}' to '{}'", m_Config.ServerStateDir, m_StorageModuleRootDir);
+ size_t SlashPos = Spec.find('/');
+ std::string UserPrefix = SlashPos != std::string_view::npos ? std::string(Spec.substr(SlashPos + 1)) : std::string{};
+ m_Bucket = std::string(SlashPos != std::string_view::npos ? Spec.substr(0, SlashPos) : Spec);
+ m_KeyPrefix = UserPrefix.empty() ? std::string(ModuleId) : UserPrefix + "/" + std::string(ModuleId);
- const std::filesystem::path TargetDir = m_StorageModuleRootDir;
+ ZEN_ASSERT(!m_Bucket.empty());
- // Ensure target is clean. This could be replaced with an atomic copy at a later date
- // (i.e copy into a temporary directory name and rename it once complete)
+ std::string Region = std::string(Settings["region"].AsString());
+ if (Region.empty())
+ {
+ Region = GetEnvVariable("AWS_DEFAULT_REGION");
+ }
+ if (Region.empty())
+ {
+ Region = GetEnvVariable("AWS_REGION");
+ }
+ if (Region.empty())
+ {
+ Region = "us-east-1";
+ }
+ m_Region = std::move(Region);
+
+ std::string AccessKeyId = GetEnvVariable("AWS_ACCESS_KEY_ID");
+ if (AccessKeyId.empty())
+ {
+ m_CredentialProvider = Ref<ImdsCredentialProvider>(new ImdsCredentialProvider({}));
+ }
+ else
+ {
+ m_Credentials.AccessKeyId = std::move(AccessKeyId);
+ m_Credentials.SecretAccessKey = GetEnvVariable("AWS_SECRET_ACCESS_KEY");
+ m_Credentials.SessionToken = GetEnvVariable("AWS_SESSION_TOKEN");
+ }
+ m_TempDir = TempDir;
+ m_Client = CreateS3Client();
+ }
- ZEN_DEBUG("Cleaning storage root '{}'", TargetDir);
- const bool ForceRemoveReadOnlyFiles = true;
- CleanDirectory(TargetDir, ForceRemoveReadOnlyFiles);
+ virtual void SaveMetadata(const CbObject& Data)
+ {
+ S3Client& Client = *m_Client;
+ BinaryWriter Output;
+ SaveCompactBinary(Output, Data);
+ IoBuffer Payload(IoBuffer::Clone, Output.GetData(), Output.GetSize());
+
+ std::string Key = m_KeyPrefix + "/incremental-state.cbo";
+ S3Result Result = Client.PutObject(Key, std::move(Payload));
+ if (!Result.IsSuccess())
+ {
+ throw zen::runtime_error("Failed to save incremental metadata to '{}': {}", Key, Result.Error);
+ }
+ }
- bool CopySuccess = true;
+ virtual CbObject LoadMetadata()
+ {
+ S3Client& Client = *m_Client;
+ std::string Key = m_KeyPrefix + "/incremental-state.cbo";
+ S3GetObjectResult Result = Client.GetObject(Key);
+ if (!Result.IsSuccess())
+ {
+ if (Result.Error == S3GetObjectResult::NotFoundErrorText)
+ {
+ return {};
+ }
+ throw zen::runtime_error("Failed to load incremental metadata from '{}': {}", Key, Result.Error);
+ }
- try
- {
- ZEN_DEBUG("Copying '{}' to '{}'", m_Config.ServerStateDir, TargetDir);
- CopyTree(m_Config.ServerStateDir, TargetDir, {.EnableClone = true});
- }
- catch (std::exception& Ex)
- {
- ZEN_WARN("Copy failed: {}. Will wipe any partially copied state from '{}'", Ex.what(), m_StorageModuleRootDir);
+ CbValidateError Error;
+ CbObject Meta = ValidateAndReadCompactBinaryObject(std::move(Result.Content), Error);
+ if (Error != CbValidateError::None)
+ {
+ throw zen::runtime_error("Failed to parse incremental metadata from '{}': {}", Key, ToString(Error));
+ }
+ return Meta;
+ }
- // We don't do the clean right here to avoid potentially running into double-throws
- CopySuccess = false;
- }
+ virtual CbObject GetSettings() override
+ {
+ CbObjectWriter Writer;
+ Writer << "MultipartChunkSize" << m_MultipartChunkSize;
+ return Writer.Save();
+ }
- if (!CopySuccess)
- {
- ZEN_DEBUG("Removing partially copied state from '{}'", TargetDir);
- CleanDirectory(TargetDir, ForceRemoveReadOnlyFiles);
- }
+ virtual void ParseSettings(const CbObjectView& Settings)
+ {
+ m_MultipartChunkSize = Settings["MultipartChunkSize"].AsUInt64(DefaultMultipartChunkSize);
+ }
- ZEN_DEBUG("Wiping server state '{}'", m_Config.ServerStateDir);
- CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles);
-}
+ virtual std::vector<IoHash> List()
+ {
+ S3Client& Client = *m_Client;
+ std::string Prefix = m_KeyPrefix + "/cas/";
+ S3ListObjectsResult Result = Client.ListObjects(Prefix);
+ if (!Result.IsSuccess())
+ {
+ throw zen::runtime_error("Failed to list S3 objects under '{}': {}", Prefix, Result.Error);
+ }
-///////////////////////////////////////////////////////////////////////////
+ std::vector<IoHash> Hashes;
+ Hashes.reserve(Result.Objects.size());
+ for (const S3ObjectInfo& Obj : Result.Objects)
+ {
+ size_t LastSlash = Obj.Key.rfind('/');
+ if (LastSlash == std::string::npos)
+ {
+ continue;
+ }
+ IoHash Hash;
+ if (IoHash::TryParse(Obj.Key.substr(LastSlash + 1), Hash))
+ {
+ Hashes.push_back(Hash);
+ }
+ }
+ return Hashes;
+ }
-constexpr std::string_view S3HydratorPrefix = "s3://";
+ virtual void Put(ParallelWork& Work,
+ WorkerThreadPool& WorkerPool,
+ const IoHash& Hash,
+ uint64_t Size,
+ const std::filesystem::path& SourcePath)
+ {
+ Work.ScheduleWork(
+ WorkerPool,
+ [this, Hash = IoHash(Hash), Size, SourcePath = std::filesystem::path(SourcePath)](std::atomic<bool>& AbortFlag) {
+ if (AbortFlag.load())
+ {
+ return;
+ }
+ S3Client& Client = *m_Client;
+ std::string Key = m_KeyPrefix + "/cas/" + fmt::format("{}", Hash);
-struct S3Hydrator : public HydrationStrategyBase
-{
- void Configure(const HydrationConfig& Config) override;
- void Dehydrate() override;
- void Hydrate() override;
+ if (Size >= (m_MultipartChunkSize + (m_MultipartChunkSize / 4)))
+ {
+ BasicFile File(SourcePath, BasicFile::Mode::kRead);
+ S3Result Result = Client.PutObjectMultipart(
+ Key,
+ Size,
+ [&File](uint64_t Offset, uint64_t ChunkSize) { return File.ReadRange(Offset, ChunkSize); },
+ m_MultipartChunkSize);
+ if (!Result.IsSuccess())
+ {
+ throw zen::runtime_error("Failed to upload '{}' to S3: {}", Key, Result.Error);
+ }
+ }
+ else
+ {
+ BasicFile File(SourcePath, BasicFile::Mode::kRead);
+ S3Result Result = Client.PutObject(Key, File.ReadAll());
+ if (!Result.IsSuccess())
+ {
+ throw zen::runtime_error("Failed to upload '{}' to S3: {}", Key, Result.Error);
+ }
+ }
+ });
+ }
-private:
- S3Client CreateS3Client() const;
- std::string BuildTimestampFolderName() const;
- std::string MakeObjectKey(std::string_view FolderName, const std::filesystem::path& RelPath) const;
-
- HydrationConfig m_Config;
- std::string m_Bucket;
- std::string m_KeyPrefix; // "<user-prefix>/<ModuleId>" or just "<ModuleId>" - no trailing slash
- std::string m_Region;
- SigV4Credentials m_Credentials;
- Ref<ImdsCredentialProvider> m_CredentialProvider;
-};
+ virtual void Get(ParallelWork& Work,
+ WorkerThreadPool& WorkerPool,
+ const IoHash& Hash,
+ uint64_t Size,
+ const std::filesystem::path& DestinationPath)
+ {
+ std::string Key = m_KeyPrefix + "/cas/" + fmt::format("{}", Hash);
-void
-S3Hydrator::Configure(const HydrationConfig& Config)
-{
- m_Config = Config;
+ if (Size >= (m_MultipartChunkSize + (m_MultipartChunkSize / 4)))
+ {
+ class WorkData
+ {
+ public:
+ WorkData(const std::filesystem::path& DestPath, uint64_t Size) : m_DestFile(DestPath, BasicFile::Mode::kTruncate)
+ {
+ PrepareFileForScatteredWrite(m_DestFile.Handle(), Size);
+ }
+ ~WorkData() { m_DestFile.Flush(); }
+ void Write(const void* Data, uint64_t Size, uint64_t Offset) { m_DestFile.Write(Data, Size, Offset); }
- std::string_view Spec = m_Config.TargetSpecification;
- Spec.remove_prefix(S3HydratorPrefix.size());
+ private:
+ BasicFile m_DestFile;
+ };
- size_t SlashPos = Spec.find('/');
- std::string UserPrefix = SlashPos != std::string_view::npos ? std::string(Spec.substr(SlashPos + 1)) : std::string{};
- m_Bucket = std::string(SlashPos != std::string_view::npos ? Spec.substr(0, SlashPos) : Spec);
- m_KeyPrefix = UserPrefix.empty() ? m_Config.ModuleId : UserPrefix + "/" + m_Config.ModuleId;
+ std::shared_ptr<WorkData> Data = std::make_shared<WorkData>(DestinationPath, Size);
- ZEN_ASSERT(!m_Bucket.empty());
+ uint64_t Offset = 0;
+ while (Offset < Size)
+ {
+ uint64_t ChunkSize = std::min<uint64_t>(m_MultipartChunkSize, Size - Offset);
+
+ Work.ScheduleWork(WorkerPool, [this, Key = Key, Offset, ChunkSize, Data](std::atomic<bool>& AbortFlag) {
+ if (AbortFlag)
+ {
+ return;
+ }
+ S3GetObjectResult Chunk = m_Client->GetObjectRange(Key, Offset, ChunkSize);
+ if (!Chunk.IsSuccess())
+ {
+ throw zen::runtime_error("Failed to download '{}' bytes [{}-{}] from S3: {}",
+ Key,
+ Offset,
+ Offset + ChunkSize - 1,
+ Chunk.Error);
+ }
+
+ Data->Write(Chunk.Content.GetData(), Chunk.Content.GetSize(), Offset);
+ });
+ Offset += ChunkSize;
+ }
+ }
+ else
+ {
+ Work.ScheduleWork(
+ WorkerPool,
+ [this, Key = Key, DestinationPath = std::filesystem::path(DestinationPath)](std::atomic<bool>& AbortFlag) {
+ if (AbortFlag)
+ {
+ return;
+ }
+ S3GetObjectResult Chunk = m_Client->GetObject(Key, m_TempDir);
+ if (!Chunk.IsSuccess())
+ {
+ throw zen::runtime_error("Failed to download '{}' from S3: {}", Key, Chunk.Error);
+ }
+
+ if (IoBufferFileReference FileRef; Chunk.Content.GetFileReference(FileRef))
+ {
+ std::error_code Ec;
+ std::filesystem::path ChunkPath = PathFromHandle(FileRef.FileHandle, Ec);
+ if (Ec)
+ {
+ WriteFile(DestinationPath, Chunk.Content);
+ }
+ else
+ {
+ Chunk.Content.SetDeleteOnClose(false);
+ Chunk.Content = {};
+ RenameFile(ChunkPath, DestinationPath, Ec);
+ if (Ec)
+ {
+ Chunk.Content = IoBufferBuilder::MakeFromFile(ChunkPath);
+ Chunk.Content.SetDeleteOnClose(true);
+ WriteFile(DestinationPath, Chunk.Content);
+ }
+ }
+ }
+ else
+ {
+ WriteFile(DestinationPath, Chunk.Content);
+ }
+ });
+ }
+ }
- std::string Region = GetEnvVariable("AWS_DEFAULT_REGION");
- if (Region.empty())
- {
- Region = GetEnvVariable("AWS_REGION");
- }
- if (Region.empty())
- {
- Region = "us-east-1";
- }
- m_Region = std::move(Region);
+ virtual void Delete(ParallelWork& Work, WorkerThreadPool& WorkerPool) override
+ {
+ std::string Prefix = m_KeyPrefix + "/";
+ S3ListObjectsResult ListResult = m_Client->ListObjects(Prefix);
+ if (!ListResult.IsSuccess())
+ {
+ throw zen::runtime_error("Failed to list S3 objects for deletion under '{}': {}", Prefix, ListResult.Error);
+ }
+ for (const S3ObjectInfo& Obj : ListResult.Objects)
+ {
+ Work.ScheduleWork(WorkerPool, [this, Key = Obj.Key](std::atomic<bool>& AbortFlag) {
+ if (AbortFlag.load())
+ {
+ return;
+ }
+ S3Result DelResult = m_Client->DeleteObject(Key);
+ if (!DelResult.IsSuccess())
+ {
+ throw zen::runtime_error("Failed to delete S3 object '{}': {}", Key, DelResult.Error);
+ }
+ });
+ }
+ }
- std::string AccessKeyId = GetEnvVariable("AWS_ACCESS_KEY_ID");
- if (AccessKeyId.empty())
- {
- m_CredentialProvider = Ref<ImdsCredentialProvider>(new ImdsCredentialProvider({}));
- }
- else
- {
- m_Credentials.AccessKeyId = std::move(AccessKeyId);
- m_Credentials.SecretAccessKey = GetEnvVariable("AWS_SECRET_ACCESS_KEY");
- m_Credentials.SessionToken = GetEnvVariable("AWS_SESSION_TOKEN");
- }
-}
+ private:
+ std::unique_ptr<S3Client> CreateS3Client() const
+ {
+ S3ClientOptions Options;
+ Options.BucketName = m_Bucket;
+ Options.Region = m_Region;
-S3Client
-S3Hydrator::CreateS3Client() const
+ CbObjectView Settings = m_Options["settings"].AsObjectView();
+ std::string_view Endpoint = Settings["endpoint"].AsString();
+ if (!Endpoint.empty())
+ {
+ Options.Endpoint = std::string(Endpoint);
+ Options.PathStyle = Settings["path-style"].AsBool();
+ }
+
+ if (m_CredentialProvider)
+ {
+ Options.CredentialProvider = m_CredentialProvider;
+ }
+ else
+ {
+ Options.Credentials = m_Credentials;
+ }
+
+ Options.HttpSettings.MaximumInMemoryDownloadSize = 16u * 1024u;
+
+ return std::make_unique<S3Client>(Options);
+ }
+
+ static constexpr uint64_t DefaultMultipartChunkSize = 32u * 1024u * 1024u;
+
+ std::string m_KeyPrefix;
+ CbObject m_Options;
+ std::string m_Bucket;
+ std::string m_Region;
+ SigV4Credentials m_Credentials;
+ Ref<ImdsCredentialProvider> m_CredentialProvider;
+ std::unique_ptr<S3Client> m_Client;
+ std::filesystem::path m_TempDir;
+ uint64_t m_MultipartChunkSize = DefaultMultipartChunkSize;
+ };
+
+} // namespace hydration_impl
+
+using namespace hydration_impl;
+
+///////////////////////////////////////////////////////////////////////////
+
+class IncrementalHydrator : public HydrationStrategyBase
{
- S3ClientOptions Options;
- Options.BucketName = m_Bucket;
- Options.Region = m_Region;
+public:
+ IncrementalHydrator(std::unique_ptr<StorageBase>&& Storage);
+ virtual ~IncrementalHydrator() override;
+ virtual void Configure(const HydrationConfig& Config) override;
+ virtual void Dehydrate(const CbObject& CachedState) override;
+ virtual CbObject Hydrate() override;
+ virtual void Obliterate() override;
- if (!m_Config.S3Endpoint.empty())
+private:
+ struct Entry
{
- Options.Endpoint = m_Config.S3Endpoint;
- Options.PathStyle = m_Config.S3PathStyle;
- }
+ std::filesystem::path RelativePath;
+ uint64_t Size;
+ uint64_t ModTick;
+ IoHash Hash;
+ };
- if (m_CredentialProvider)
- {
- Options.CredentialProvider = m_CredentialProvider;
- }
- else
- {
- Options.Credentials = m_Credentials;
- }
+ std::unique_ptr<StorageBase> m_Storage;
+ HydrationConfig m_Config;
+ WorkerThreadPool m_FallbackWorkPool;
+ std::atomic<bool> m_FallbackAbortFlag{false};
+ std::atomic<bool> m_FallbackPauseFlag{false};
+ HydrationConfig::ThreadingOptions m_Threading{.WorkerPool = &m_FallbackWorkPool,
+ .AbortFlag = &m_FallbackAbortFlag,
+ .PauseFlag = &m_FallbackPauseFlag};
+};
- return S3Client(Options);
+IncrementalHydrator::IncrementalHydrator(std::unique_ptr<StorageBase>&& Storage) : m_Storage(std::move(Storage)), m_FallbackWorkPool(0)
+{
}
-std::string
-S3Hydrator::BuildTimestampFolderName() const
+IncrementalHydrator::~IncrementalHydrator()
{
- UtcTime Now = UtcTime::Now();
- return fmt::format("{:04d}{:02d}{:02d}-{:02d}{:02d}{:02d}-{:03d}",
- Now.Tm.tm_year + 1900,
- Now.Tm.tm_mon + 1,
- Now.Tm.tm_mday,
- Now.Tm.tm_hour,
- Now.Tm.tm_min,
- Now.Tm.tm_sec,
- Now.Ms);
+ m_Storage.reset();
}
-std::string
-S3Hydrator::MakeObjectKey(std::string_view FolderName, const std::filesystem::path& RelPath) const
+void
+IncrementalHydrator::Configure(const HydrationConfig& Config)
{
- return m_KeyPrefix + "/" + std::string(FolderName) + "/" + RelPath.generic_string();
+ m_Config = Config;
+ m_Storage->Configure(Config.ModuleId, Config.TempDir, Config.TargetSpecification, Config.Options);
+ if (Config.Threading)
+ {
+ m_Threading = *Config.Threading;
+ }
}
void
-S3Hydrator::Dehydrate()
+IncrementalHydrator::Dehydrate(const CbObject& CachedState)
{
- ZEN_INFO("Dehydrating state from '{}' to s3://{}/{}", m_Config.ServerStateDir, m_Bucket, m_KeyPrefix);
+ Stopwatch Timer;
+ const std::filesystem::path ServerStateDir = MakeSafeAbsolutePath(m_Config.ServerStateDir);
try
{
- S3Client Client = CreateS3Client();
- std::string FolderName = BuildTimestampFolderName();
- uint64_t TotalBytes = 0;
- uint32_t FileCount = 0;
- std::chrono::steady_clock::time_point UploadStart = std::chrono::steady_clock::now();
+ std::unordered_map<std::string, size_t> StateEntryLookup;
+ std::vector<Entry> StateEntries;
+ for (CbFieldView FieldView : CachedState["Files"].AsArrayView())
+ {
+ CbObjectView EntryView = FieldView.AsObjectView();
+ std::filesystem::path RelativePath(EntryView["Path"].AsString());
+ uint64_t Size = EntryView["Size"].AsUInt64();
+ uint64_t ModTick = EntryView["ModTick"].AsUInt64();
+ IoHash Hash = EntryView["Hash"].AsHash();
+
+ StateEntryLookup.insert_or_assign(RelativePath.generic_string(), StateEntries.size());
+ StateEntries.push_back(Entry{.RelativePath = RelativePath, .Size = Size, .ModTick = ModTick, .Hash = Hash});
+ }
DirectoryContent DirContent;
- GetDirectoryContent(m_Config.ServerStateDir, DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::Recursive, DirContent);
+ GetDirectoryContent(*m_Threading.WorkerPool,
+ ServerStateDir,
+ DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::Recursive |
+ DirectoryContentFlags::IncludeFileSizes | DirectoryContentFlags::IncludeModificationTick,
+ DirContent);
+
+ ZEN_INFO("Dehydrating module '{}' from folder '{}'. {} ({}) files",
+ m_Config.ModuleId,
+ m_Config.ServerStateDir,
+ DirContent.Files.size(),
+ NiceBytes(std::accumulate(DirContent.FileSizes.begin(), DirContent.FileSizes.end(), uint64_t(0))));
+
+ std::vector<Entry> Entries;
+ Entries.resize(DirContent.Files.size());
+
+ uint64_t TotalBytes = 0;
+ uint64_t TotalFiles = 0;
+ uint64_t HashedFiles = 0;
+ uint64_t HashedBytes = 0;
+
+ std::unordered_set<IoHash> ExistsLookup;
- for (const std::filesystem::path& AbsPath : DirContent.Files)
{
- std::filesystem::path RelPath = AbsPath.lexically_relative(m_Config.ServerStateDir);
- if (RelPath.empty() || *RelPath.begin() == "..")
+ ParallelWork Work(*m_Threading.AbortFlag, *m_Threading.PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+
+ for (size_t FileIndex = 0; FileIndex < DirContent.Files.size(); FileIndex++)
{
- throw zen::runtime_error(
- "lexically_relative produced a '..'-escape path for '{}' relative to '{}' - "
- "path form mismatch (e.g. \\\\?\\ prefix on one but not the other)",
- AbsPath.string(),
- m_Config.ServerStateDir.string());
- }
- std::string Key = MakeObjectKey(FolderName, RelPath);
+ const std::filesystem::path AbsPath = MakeSafeAbsolutePath(DirContent.Files[FileIndex]);
+ if (AbsPath.filename() == "reserve.gc")
+ {
+ continue;
+ }
+ const std::filesystem::path RelativePath = FastRelativePath(ServerStateDir, DirContent.Files[FileIndex]);
+ if (*RelativePath.begin() == ".sentry-native")
+ {
+ continue;
+ }
+ if (RelativePath == ".lock")
+ {
+ continue;
+ }
- BasicFile File(AbsPath, BasicFile::Mode::kRead);
- uint64_t FileSize = File.FileSize();
+ Entry& CurrentEntry = Entries[TotalFiles];
+ CurrentEntry.RelativePath = RelativePath;
+ CurrentEntry.Size = DirContent.FileSizes[FileIndex];
+ CurrentEntry.ModTick = DirContent.FileModificationTicks[FileIndex];
- S3Result UploadResult =
- Client.PutObjectMultipart(Key, FileSize, [&File](uint64_t Offset, uint64_t Size) { return File.ReadRange(Offset, Size); });
- if (!UploadResult.IsSuccess())
- {
- throw zen::runtime_error("Failed to upload '{}' to S3: {}", Key, UploadResult.Error);
+ bool FoundHash = false;
+ if (auto KnownIt = StateEntryLookup.find(CurrentEntry.RelativePath.generic_string()); KnownIt != StateEntryLookup.end())
+ {
+ const Entry& StateEntry = StateEntries[KnownIt->second];
+ if (StateEntry.Size == CurrentEntry.Size && StateEntry.ModTick == CurrentEntry.ModTick)
+ {
+ CurrentEntry.Hash = StateEntry.Hash;
+ FoundHash = true;
+ }
+ }
+
+ if (!FoundHash)
+ {
+ Work.ScheduleWork(*m_Threading.WorkerPool, [AbsPath, EntryIndex = TotalFiles, &Entries](std::atomic<bool>& AbortFlag) {
+ if (AbortFlag)
+ {
+ return;
+ }
+
+ Entry& CurrentEntry = Entries[EntryIndex];
+
+ bool FoundHash = false;
+ if (AbsPath.extension().empty())
+ {
+ auto It = CurrentEntry.RelativePath.begin();
+ if (It != CurrentEntry.RelativePath.end() && It->filename().string().ends_with("cas"))
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed =
+ CompressedBuffer::FromCompressed(SharedBuffer(IoBufferBuilder::MakeFromFile(AbsPath)),
+ RawHash,
+ RawSize);
+ if (Compressed)
+ {
+ // We compose a meta-hash since taking the RawHash might collide with an existing
+ // non-compressed file with the same content The collision is unlikely except if the
+ // compressed data is zero bytes causing RawHash to be the same as an empty file.
+ IoHashStream Hasher;
+ Hasher.Append(RawHash.Hash, sizeof(RawHash.Hash));
+ Hasher.Append(&CurrentEntry.Size, sizeof(CurrentEntry.Size));
+ CurrentEntry.Hash = Hasher.GetHash();
+ FoundHash = true;
+ }
+ }
+ }
+
+ if (!FoundHash)
+ {
+ CurrentEntry.Hash = IoHash::HashBuffer(IoBufferBuilder::MakeFromFile(AbsPath));
+ }
+ });
+ HashedFiles++;
+ HashedBytes += CurrentEntry.Size;
+ }
+ TotalFiles++;
+ TotalBytes += CurrentEntry.Size;
}
- TotalBytes += FileSize;
- ++FileCount;
+ std::vector<IoHash> ExistingEntries = m_Storage->List();
+ ExistsLookup.insert(ExistingEntries.begin(), ExistingEntries.end());
+
+ Work.Wait();
+
+ Entries.resize(TotalFiles);
}
- // Write current-state.json
- int64_t UploadDurationMs =
- std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - UploadStart).count();
-
- UtcTime Now = UtcTime::Now();
- std::string UploadTimeUtc = fmt::format("{:04d}-{:02d}-{:02d}T{:02d}:{:02d}:{:02d}.{:03d}Z",
- Now.Tm.tm_year + 1900,
- Now.Tm.tm_mon + 1,
- Now.Tm.tm_mday,
- Now.Tm.tm_hour,
- Now.Tm.tm_min,
- Now.Tm.tm_sec,
- Now.Ms);
-
- CbObjectWriter Meta;
- Meta << "FolderName" << FolderName;
- Meta << "ModuleId" << m_Config.ModuleId;
- Meta << "HostName" << GetMachineName();
- Meta << "UploadTimeUtc" << UploadTimeUtc;
- Meta << "UploadDurationMs" << UploadDurationMs;
- Meta << "TotalSizeBytes" << TotalBytes;
- Meta << "FileCount" << FileCount;
-
- ExtendableStringBuilder<1024> JsonBuilder;
- Meta.Save().ToJson(JsonBuilder);
-
- std::string MetaKey = m_KeyPrefix + "/current-state.json";
- std::string_view JsonText = JsonBuilder.ToView();
- IoBuffer MetaBuf(IoBuffer::Clone, JsonText.data(), JsonText.size());
- S3Result MetaUploadResult = Client.PutObject(MetaKey, std::move(MetaBuf));
- if (!MetaUploadResult.IsSuccess())
+ uint64_t UploadedFiles = 0;
+ uint64_t UploadedBytes = 0;
{
- throw zen::runtime_error("Failed to write current-state.json to '{}': {}", MetaKey, MetaUploadResult.Error);
+ ParallelWork Work(*m_Threading.AbortFlag, *m_Threading.PauseFlag, WorkerThreadPool::EMode::DisableBacklog);
+
+ for (const Entry& CurrentEntry : Entries)
+ {
+ if (!ExistsLookup.contains(CurrentEntry.Hash))
+ {
+ m_Storage->Put(Work,
+ *m_Threading.WorkerPool,
+ CurrentEntry.Hash,
+ CurrentEntry.Size,
+ MakeSafeAbsolutePath(ServerStateDir / CurrentEntry.RelativePath));
+ UploadedFiles++;
+ UploadedBytes += CurrentEntry.Size;
+ }
+ }
+
+ Work.Wait();
+ uint64_t UploadTimeMs = Timer.GetElapsedTimeMs();
+
+ UtcTime Now = UtcTime::Now();
+ std::string UploadTimeUtc = fmt::format("{:04d}-{:02d}-{:02d}T{:02d}:{:02d}:{:02d}.{:03d}Z",
+ Now.Tm.tm_year + 1900,
+ Now.Tm.tm_mon + 1,
+ Now.Tm.tm_mday,
+ Now.Tm.tm_hour,
+ Now.Tm.tm_min,
+ Now.Tm.tm_sec,
+ Now.Ms);
+
+ CbObjectWriter Meta;
+ Meta << "SourceFolder" << ServerStateDir.generic_string();
+ Meta << "ModuleId" << m_Config.ModuleId;
+ Meta << "HostName" << GetMachineName();
+ Meta << "UploadTimeUtc" << UploadTimeUtc;
+ Meta << "UploadDurationMs" << UploadTimeMs;
+ Meta << "TotalSizeBytes" << TotalBytes;
+ Meta << "StorageSettings" << m_Storage->GetSettings();
+
+ Meta.BeginArray("Files");
+ for (const Entry& CurrentEntry : Entries)
+ {
+ Meta.BeginObject();
+ {
+ Meta << "Path" << CurrentEntry.RelativePath.generic_string();
+ Meta << "Size" << CurrentEntry.Size;
+ Meta << "ModTick" << CurrentEntry.ModTick;
+ Meta << "Hash" << CurrentEntry.Hash;
+ }
+ Meta.EndObject();
+ }
+ Meta.EndArray();
+
+ m_Storage->SaveMetadata(Meta.Save());
}
- ZEN_INFO("Dehydration complete: {} files, {} bytes, {} ms", FileCount, TotalBytes, UploadDurationMs);
+ ZEN_DEBUG("Cleaning server state '{}'", m_Config.ServerStateDir);
+ CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, ServerStateDir);
+
+ ZEN_INFO("Dehydration of module '{}' completed from folder '{}'. Hashed {} ({}). Uploaded {} ({}). Total {} ({}) in {}",
+ m_Config.ModuleId,
+ m_Config.ServerStateDir,
+ HashedFiles,
+ NiceBytes(HashedBytes),
+ UploadedFiles,
+ NiceBytes(UploadedBytes),
+ TotalFiles,
+ NiceBytes(TotalBytes),
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}
- catch (std::exception& Ex)
+ catch (const std::exception& Ex)
{
- // Any in-progress multipart upload has already been aborted by PutObjectMultipart.
- // current-state.json is only written on success, so the previous S3 state remains valid.
- ZEN_WARN("S3 dehydration failed: {}. S3 state not updated.", Ex.what());
+ ZEN_WARN("Dehydration of module '{}' failed: {}. Leaving server state '{}'", m_Config.ModuleId, Ex.what(), m_Config.ServerStateDir);
}
}
-void
-S3Hydrator::Hydrate()
+CbObject
+IncrementalHydrator::Hydrate()
{
- ZEN_INFO("Hydrating state from s3://{}/{} to '{}'", m_Bucket, m_KeyPrefix, m_Config.ServerStateDir);
-
- const bool ForceRemoveReadOnlyFiles = true;
-
- // Clean temp dir before starting in case of leftover state from a previous failed hydration
- ZEN_DEBUG("Cleaning temp dir '{}'", m_Config.TempDir);
- CleanDirectory(m_Config.TempDir, ForceRemoveReadOnlyFiles);
-
- bool WipeServerState = false;
+ Stopwatch Timer;
+ const std::filesystem::path ServerStateDir = MakeSafeAbsolutePath(m_Config.ServerStateDir);
+ const std::filesystem::path TempDir = MakeSafeAbsolutePath(m_Config.TempDir);
try
{
- S3Client Client = CreateS3Client();
- std::string MetaKey = m_KeyPrefix + "/current-state.json";
-
- S3HeadObjectResult HeadResult = Client.HeadObject(MetaKey);
- if (HeadResult.Status == HeadObjectResult::NotFound)
- {
- throw zen::runtime_error("No state found in S3 at '{}'", MetaKey);
- }
- if (!HeadResult.IsSuccess())
+ CbObject Meta = m_Storage->LoadMetadata();
+ if (!Meta)
{
- throw zen::runtime_error("Failed to check for state in S3 at '{}': {}", MetaKey, HeadResult.Error);
+ ZEN_INFO("No dehydrated state for module {} found, cleaning server state: '{}'", m_Config.ModuleId, m_Config.ServerStateDir);
+ CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, ServerStateDir);
+ return CbObject();
}
- S3GetObjectResult MetaResult = Client.GetObject(MetaKey);
- if (!MetaResult.IsSuccess())
- {
- throw zen::runtime_error("Failed to read current-state.json from '{}': {}", MetaKey, MetaResult.Error);
- }
+ std::unordered_map<std::string, size_t> EntryLookup;
+ std::vector<Entry> Entries;
+ uint64_t TotalSize = 0;
- std::string ParseError;
- json11::Json MetaJson = json11::Json::parse(std::string(MetaResult.AsText()), ParseError);
- if (!ParseError.empty())
+ for (CbFieldView FieldView : Meta["Files"])
{
- throw zen::runtime_error("Failed to parse current-state.json from '{}': {}", MetaKey, ParseError);
+ CbObjectView EntryView = FieldView.AsObjectView();
+ if (EntryView)
+ {
+ Entry NewEntry = {.RelativePath = std::filesystem::path(EntryView["Path"].AsString()),
+ .Size = EntryView["Size"].AsUInt64(),
+ .ModTick = EntryView["ModTick"].AsUInt64(),
+ .Hash = EntryView["Hash"].AsHash()};
+ TotalSize += NewEntry.Size;
+ EntryLookup.insert_or_assign(NewEntry.RelativePath.generic_string(), Entries.size());
+ Entries.emplace_back(std::move(NewEntry));
+ }
}
- std::string FolderName = MetaJson["FolderName"].string_value();
- if (FolderName.empty())
- {
- throw zen::runtime_error("current-state.json from '{}' has missing or empty FolderName", MetaKey);
- }
+ ZEN_INFO("Hydrating module '{}' to folder '{}'. {} ({}) files",
+ m_Config.ModuleId,
+ m_Config.ServerStateDir,
+ Entries.size(),
+ NiceBytes(TotalSize));
- std::string FolderPrefix = m_KeyPrefix + "/" + FolderName + "/";
- S3ListObjectsResult ListResult = Client.ListObjects(FolderPrefix);
- if (!ListResult.IsSuccess())
- {
- throw zen::runtime_error("Failed to list S3 objects under '{}': {}", FolderPrefix, ListResult.Error);
- }
+ m_Storage->ParseSettings(Meta["StorageSettings"].AsObjectView());
+
+ uint64_t DownloadedBytes = 0;
+ uint64_t DownloadedFiles = 0;
- for (const S3ObjectInfo& Obj : ListResult.Objects)
{
- if (!Obj.Key.starts_with(FolderPrefix))
- {
- ZEN_WARN("Skipping unexpected S3 key '{}' (expected prefix '{}')", Obj.Key, FolderPrefix);
- continue;
- }
+ ParallelWork Work(*m_Threading.AbortFlag, *m_Threading.PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
- std::string RelKey = Obj.Key.substr(FolderPrefix.size());
- if (RelKey.empty())
+ for (const Entry& CurrentEntry : Entries)
{
- continue;
+ std::filesystem::path Path = MakeSafeAbsolutePath(TempDir / CurrentEntry.RelativePath);
+ CreateDirectories(Path.parent_path());
+ m_Storage->Get(Work, *m_Threading.WorkerPool, CurrentEntry.Hash, CurrentEntry.Size, Path);
+ DownloadedBytes += CurrentEntry.Size;
+ DownloadedFiles++;
}
- std::filesystem::path DestPath = MakeSafeAbsolutePath(m_Config.TempDir / std::filesystem::path(RelKey));
- CreateDirectories(DestPath.parent_path());
- BasicFile DestFile(DestPath, BasicFile::Mode::kTruncate);
- DestFile.SetFileSize(Obj.Size);
-
- if (Obj.Size > 0)
- {
- BasicFileWriter Writer(DestFile, 64 * 1024);
-
- uint64_t Offset = 0;
- while (Offset < Obj.Size)
- {
- uint64_t ChunkSize = std::min<uint64_t>(8 * 1024 * 1024, Obj.Size - Offset);
- S3GetObjectResult Chunk = Client.GetObjectRange(Obj.Key, Offset, ChunkSize);
- if (!Chunk.IsSuccess())
- {
- throw zen::runtime_error("Failed to download '{}' bytes [{}-{}] from S3: {}",
- Obj.Key,
- Offset,
- Offset + ChunkSize - 1,
- Chunk.Error);
- }
-
- Writer.Write(Chunk.Content.GetData(), Chunk.Content.GetSize(), Offset);
- Offset += ChunkSize;
- }
-
- Writer.Flush();
- }
+ Work.Wait();
}
// Downloaded successfully - swap into ServerStateDir
- ZEN_DEBUG("Wiping server state '{}'", m_Config.ServerStateDir);
- CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles);
+ ZEN_DEBUG("Cleaning server state '{}'", m_Config.ServerStateDir);
+ CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, ServerStateDir);
// If the two paths share at least one common component they are on the same drive/volume
// and atomic renames will succeed. Otherwise fall back to a full copy.
- auto [ItTmp, ItState] =
- std::mismatch(m_Config.TempDir.begin(), m_Config.TempDir.end(), m_Config.ServerStateDir.begin(), m_Config.ServerStateDir.end());
- if (ItTmp != m_Config.TempDir.begin())
+ auto [ItTmp, ItState] = std::mismatch(TempDir.begin(), TempDir.end(), ServerStateDir.begin(), ServerStateDir.end());
+ if (ItTmp != TempDir.begin())
{
- // Fast path: atomic renames - no data copying needed
- for (const std::filesystem::directory_entry& Entry : std::filesystem::directory_iterator(m_Config.TempDir))
+ DirectoryContent DirContent;
+ GetDirectoryContent(*m_Threading.WorkerPool,
+ TempDir,
+ DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::IncludeDirs,
+ DirContent);
+
+ for (const std::filesystem::path& AbsPath : DirContent.Directories)
{
- std::filesystem::path Dest = MakeSafeAbsolutePath(m_Config.ServerStateDir / Entry.path().filename());
- if (Entry.is_directory())
+ std::filesystem::path Dest = MakeSafeAbsolutePath(ServerStateDir / AbsPath.filename());
+ std::error_code Ec = RenameDirectoryWithRetry(AbsPath, Dest);
+ if (Ec)
{
- RenameDirectory(Entry.path(), Dest);
+ throw std::system_error(Ec, fmt::format("Failed to rename directory from '{}' to '{}'", AbsPath, Dest));
}
- else
+ }
+ for (const std::filesystem::path& AbsPath : DirContent.Files)
+ {
+ std::filesystem::path Dest = MakeSafeAbsolutePath(ServerStateDir / AbsPath.filename());
+ std::error_code Ec = RenameFileWithRetry(AbsPath, Dest);
+ if (Ec)
{
- RenameFile(Entry.path(), Dest);
+ throw std::system_error(Ec, fmt::format("Failed to rename file from '{}' to '{}'", AbsPath, Dest));
}
}
+
ZEN_DEBUG("Cleaning temp dir '{}'", m_Config.TempDir);
- CleanDirectory(m_Config.TempDir, ForceRemoveReadOnlyFiles);
+ CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, TempDir);
}
else
{
// Slow path: TempDir and ServerStateDir are on different filesystems, so rename
// would fail. Copy the tree instead and clean up the temp files afterwards.
ZEN_DEBUG("TempDir and ServerStateDir are on different filesystems - using CopyTree");
- CopyTree(m_Config.TempDir, m_Config.ServerStateDir, {.EnableClone = true});
+ CopyTree(TempDir, ServerStateDir, {.EnableClone = true});
ZEN_DEBUG("Cleaning temp dir '{}'", m_Config.TempDir);
- CleanDirectory(m_Config.TempDir, ForceRemoveReadOnlyFiles);
+ CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, TempDir);
+ }
+
+ // TODO: This could perhaps be done more efficently, but ok for now
+ DirectoryContent DirContent;
+ GetDirectoryContent(*m_Threading.WorkerPool,
+ ServerStateDir,
+ DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::Recursive |
+ DirectoryContentFlags::IncludeFileSizes | DirectoryContentFlags::IncludeModificationTick,
+ DirContent);
+
+ CbObjectWriter HydrateState;
+ HydrateState.BeginArray("Files");
+ for (size_t FileIndex = 0; FileIndex < DirContent.Files.size(); FileIndex++)
+ {
+ std::filesystem::path RelativePath = FastRelativePath(ServerStateDir, DirContent.Files[FileIndex]);
+
+ if (auto It = EntryLookup.find(RelativePath.generic_string()); It != EntryLookup.end())
+ {
+ HydrateState.BeginObject();
+ {
+ HydrateState << "Path" << RelativePath.generic_string();
+ HydrateState << "Size" << DirContent.FileSizes[FileIndex];
+ HydrateState << "ModTick" << DirContent.FileModificationTicks[FileIndex];
+ HydrateState << "Hash" << Entries[It->second].Hash;
+ }
+ HydrateState.EndObject();
+ }
+ else
+ {
+ ZEN_ASSERT(false);
+ }
}
+ HydrateState.EndArray();
+
+ CbObject StateObject = HydrateState.Save();
- ZEN_INFO("Hydration complete from folder '{}'", FolderName);
+ ZEN_INFO("Hydration of module '{}' complete to folder '{}'. {} ({}) files in {}",
+ m_Config.ModuleId,
+ m_Config.ServerStateDir,
+ DownloadedFiles,
+ NiceBytes(DownloadedBytes),
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+
+ return StateObject;
}
- catch (std::exception& Ex)
+ catch (const std::exception& Ex)
{
- ZEN_WARN("S3 hydration failed: {}. Will wipe any partially installed state.", Ex.what());
-
- // We don't do the clean right here to avoid potentially running into double-throws
- WipeServerState = true;
+ ZEN_WARN("Hydration of module '{}' failed: {}. Cleaning server state '{}'", m_Config.ModuleId, Ex.what(), m_Config.ServerStateDir);
+ CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, ServerStateDir);
+ ZEN_DEBUG("Cleaning temp dir '{}'", m_Config.TempDir);
+ CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, TempDir);
+ return {};
}
+}
- if (WipeServerState)
+void
+IncrementalHydrator::Obliterate()
+{
+ const std::filesystem::path ServerStateDir = MakeSafeAbsolutePath(m_Config.ServerStateDir);
+ const std::filesystem::path TempDir = MakeSafeAbsolutePath(m_Config.TempDir);
+
+ try
{
- ZEN_DEBUG("Cleaning server state '{}'", m_Config.ServerStateDir);
- CleanDirectory(m_Config.ServerStateDir, ForceRemoveReadOnlyFiles);
- ZEN_DEBUG("Cleaning temp dir '{}'", m_Config.TempDir);
- CleanDirectory(m_Config.TempDir, ForceRemoveReadOnlyFiles);
+ ParallelWork Work(*m_Threading.AbortFlag, *m_Threading.PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
+ m_Storage->Delete(Work, *m_Threading.WorkerPool);
+ Work.Wait();
}
+ catch (const std::exception& Ex)
+ {
+ ZEN_WARN("Failed to delete backend storage for module '{}': {}. Proceeding with local cleanup.", m_Config.ModuleId, Ex.what());
+ }
+
+ CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, ServerStateDir);
+ CleanDirectory(*m_Threading.WorkerPool, *m_Threading.AbortFlag, *m_Threading.PauseFlag, TempDir);
}
std::unique_ptr<HydrationStrategyBase>
CreateHydrator(const HydrationConfig& Config)
{
- if (StrCaseCompare(Config.TargetSpecification.substr(0, FileHydratorPrefix.length()), FileHydratorPrefix) == 0)
+ std::unique_ptr<StorageBase> Storage;
+
+ if (!Config.TargetSpecification.empty())
{
- std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<FileHydrator>();
- Hydrator->Configure(Config);
- return Hydrator;
+ if (StrCaseCompare(Config.TargetSpecification.substr(0, FileHydratorPrefix.length()), FileHydratorPrefix) == 0)
+ {
+ Storage = std::make_unique<FileStorage>();
+ }
+ else if (StrCaseCompare(Config.TargetSpecification.substr(0, S3HydratorPrefix.length()), S3HydratorPrefix) == 0)
+ {
+ Storage = std::make_unique<S3Storage>();
+ }
+ else
+ {
+ throw std::runtime_error(fmt::format("Unknown hydration strategy: {}", Config.TargetSpecification));
+ }
}
- if (StrCaseCompare(Config.TargetSpecification.substr(0, S3HydratorPrefix.length()), S3HydratorPrefix) == 0)
+ else
{
- std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<S3Hydrator>();
- Hydrator->Configure(Config);
- return Hydrator;
+ std::string_view Type = Config.Options["type"].AsString();
+ if (Type == FileHydratorType)
+ {
+ Storage = std::make_unique<FileStorage>();
+ }
+ else if (Type == S3HydratorType)
+ {
+ Storage = std::make_unique<S3Storage>();
+ }
+ else if (!Type.empty())
+ {
+ throw zen::runtime_error("Unknown hydration target type '{}'", Type);
+ }
+ else
+ {
+ throw zen::runtime_error("No hydration target configured");
+ }
}
- throw std::runtime_error(fmt::format("Unknown hydration strategy: {}", Config.TargetSpecification));
+
+ auto Hydrator = std::make_unique<IncrementalHydrator>(std::move(Storage));
+ Hydrator->Configure(Config);
+ return Hydrator;
}
#if ZEN_WITH_TESTS
namespace {
+ struct TestThreading
+ {
+ WorkerThreadPool WorkerPool;
+ std::atomic<bool> AbortFlag{false};
+ std::atomic<bool> PauseFlag{false};
+ HydrationConfig::ThreadingOptions Options{.WorkerPool = &WorkerPool, .AbortFlag = &AbortFlag, .PauseFlag = &PauseFlag};
+
+ explicit TestThreading(int ThreadCount) : WorkerPool(ThreadCount) {}
+ };
+
/// Scoped RAII helper to set/restore a single environment variable within a test.
/// Used to configure AWS credentials for each S3 test's MinIO instance
/// without polluting the global environment.
@@ -593,10 +1161,10 @@ namespace {
/// subdir/file_b.bin
/// subdir/nested/file_c.bin
/// Returns a vector of (relative path, content) pairs for later verification.
- std::vector<std::pair<std::filesystem::path, IoBuffer>> CreateTestTree(const std::filesystem::path& BaseDir)
- {
- std::vector<std::pair<std::filesystem::path, IoBuffer>> Files;
+ typedef std::vector<std::pair<std::filesystem::path, IoBuffer>> TestFileList;
+ TestFileList AddTestFiles(const std::filesystem::path& BaseDir, TestFileList& Files)
+ {
auto AddFile = [&](std::filesystem::path RelPath, IoBuffer Content) {
std::filesystem::path FullPath = BaseDir / RelPath;
CreateDirectories(FullPath.parent_path());
@@ -607,6 +1175,36 @@ namespace {
AddFile("file_a.bin", CreateSemiRandomBlob(1024));
AddFile("subdir/file_b.bin", CreateSemiRandomBlob(2048));
AddFile("subdir/nested/file_c.bin", CreateSemiRandomBlob(512));
+ AddFile("subdir/nested/file_d.bin", CreateSemiRandomBlob(512));
+ AddFile("subdir/nested/file_e.bin", CreateSemiRandomBlob(512));
+ AddFile("subdir/nested/file_f.bin", CreateSemiRandomBlob(512));
+
+ return Files;
+ }
+
+ TestFileList CreateSmallTestTree(const std::filesystem::path& BaseDir)
+ {
+ TestFileList Files;
+ AddTestFiles(BaseDir, Files);
+ return Files;
+ }
+
+ TestFileList CreateTestTree(const std::filesystem::path& BaseDir)
+ {
+ TestFileList Files;
+ AddTestFiles(BaseDir, Files);
+
+ auto AddFile = [&](std::filesystem::path RelPath, IoBuffer Content) {
+ std::filesystem::path FullPath = BaseDir / RelPath;
+ CreateDirectories(FullPath.parent_path());
+ WriteFile(FullPath, Content);
+ Files.emplace_back(std::move(RelPath), std::move(Content));
+ };
+
+ AddFile("subdir/nested/medium.bulk", CreateSemiRandomBlob(256u * 1024u));
+ AddFile("subdir/nested/big.bulk", CreateSemiRandomBlob(512u * 1024u));
+ AddFile("subdir/nested/huge.bulk", CreateSemiRandomBlob(9u * 1024u * 1024u));
+ AddFile("subdir/nested/biggest.bulk", CreateSemiRandomBlob(63u * 1024u * 1024u));
return Files;
}
@@ -644,7 +1242,7 @@ TEST_CASE("hydration.file.dehydrate_hydrate")
CreateDirectories(HydrationTemp);
const std::string ModuleId = "testmodule";
- auto TestFiles = CreateTestTree(ServerStateDir);
+ auto TestFiles = CreateSmallTestTree(ServerStateDir);
HydrationConfig Config;
Config.ServerStateDir = ServerStateDir;
@@ -655,7 +1253,7 @@ TEST_CASE("hydration.file.dehydrate_hydrate")
// Dehydrate: copy server state to file store
{
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
- Hydrator->Dehydrate();
+ Hydrator->Dehydrate(CbObject());
}
// Verify the module folder exists in the store and ServerStateDir was wiped
@@ -672,7 +1270,7 @@ TEST_CASE("hydration.file.dehydrate_hydrate")
VerifyTree(ServerStateDir, TestFiles);
}
-TEST_CASE("hydration.file.dehydrate_cleans_server_state")
+TEST_CASE("hydration.file.hydrate_overwrites_existing_state")
{
ScopedTemporaryDirectory TempDir;
@@ -683,7 +1281,7 @@ TEST_CASE("hydration.file.dehydrate_cleans_server_state")
CreateDirectories(HydrationStore);
CreateDirectories(HydrationTemp);
- CreateTestTree(ServerStateDir);
+ auto TestFiles = CreateSmallTestTree(ServerStateDir);
HydrationConfig Config;
Config.ServerStateDir = ServerStateDir;
@@ -691,14 +1289,26 @@ TEST_CASE("hydration.file.dehydrate_cleans_server_state")
Config.ModuleId = "testmodule";
Config.TargetSpecification = "file://" + HydrationStore.string();
- std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
- Hydrator->Dehydrate();
+ // Dehydrate the original state
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Dehydrate(CbObject());
+ }
- // FileHydrator::Dehydrate() must wipe ServerStateDir when done
- CHECK(std::filesystem::is_empty(ServerStateDir));
+ // Put a stale file in ServerStateDir to simulate leftover state
+ WriteFile(ServerStateDir / "stale.bin", CreateSemiRandomBlob(256));
+
+ // Hydrate - must wipe stale file and restore original
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Hydrate();
+ }
+
+ CHECK_FALSE(std::filesystem::exists(ServerStateDir / "stale.bin"));
+ VerifyTree(ServerStateDir, TestFiles);
}
-TEST_CASE("hydration.file.hydrate_overwrites_existing_state")
+TEST_CASE("hydration.file.excluded_files_not_dehydrated")
{
ScopedTemporaryDirectory TempDir;
@@ -709,31 +1319,86 @@ TEST_CASE("hydration.file.hydrate_overwrites_existing_state")
CreateDirectories(HydrationStore);
CreateDirectories(HydrationTemp);
- auto TestFiles = CreateTestTree(ServerStateDir);
+ auto TestFiles = CreateSmallTestTree(ServerStateDir);
+
+ // Add files that the dehydrator should skip
+ WriteFile(ServerStateDir / "reserve.gc", CreateSemiRandomBlob(64));
+ CreateDirectories(ServerStateDir / ".sentry-native");
+ WriteFile(ServerStateDir / ".sentry-native" / "db.lock", CreateSemiRandomBlob(32));
+ WriteFile(ServerStateDir / ".sentry-native" / "breadcrumb.json", CreateSemiRandomBlob(128));
HydrationConfig Config;
Config.ServerStateDir = ServerStateDir;
Config.TempDir = HydrationTemp;
- Config.ModuleId = "testmodule";
+ Config.ModuleId = "testmodule_excl";
Config.TargetSpecification = "file://" + HydrationStore.string();
- // Dehydrate the original state
{
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
- Hydrator->Dehydrate();
+ Hydrator->Dehydrate(CbObject());
}
- // Put a stale file in ServerStateDir to simulate leftover state
- WriteFile(ServerStateDir / "stale.bin", CreateSemiRandomBlob(256));
-
- // Hydrate - must wipe stale file and restore original
+ // Hydrate into a clean directory
+ CleanDirectory(ServerStateDir, true);
{
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
Hydrator->Hydrate();
}
- CHECK_FALSE(std::filesystem::exists(ServerStateDir / "stale.bin"));
+ // Normal files must be restored
VerifyTree(ServerStateDir, TestFiles);
+ // Excluded files must NOT be restored
+ CHECK_FALSE(std::filesystem::exists(ServerStateDir / "reserve.gc"));
+ CHECK_FALSE(std::filesystem::exists(ServerStateDir / ".sentry-native"));
+}
+
+// ---------------------------------------------------------------------------
+// FileHydrator obliterate test
+// ---------------------------------------------------------------------------
+
+TEST_CASE("hydration.file.obliterate")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ std::filesystem::path ServerStateDir = TempDir.Path() / "server_state";
+ std::filesystem::path HydrationStore = TempDir.Path() / "hydration_store";
+ std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp";
+ CreateDirectories(ServerStateDir);
+ CreateDirectories(HydrationStore);
+ CreateDirectories(HydrationTemp);
+
+ const std::string ModuleId = "obliterate_test";
+ CreateSmallTestTree(ServerStateDir);
+
+ HydrationConfig Config;
+ Config.ServerStateDir = ServerStateDir;
+ Config.TempDir = HydrationTemp;
+ Config.ModuleId = ModuleId;
+ Config.TargetSpecification = "file://" + HydrationStore.string();
+
+ // Dehydrate so the backend store has data
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Dehydrate(CbObject());
+ }
+ CHECK(std::filesystem::exists(HydrationStore / ModuleId));
+
+ // Put some files back in ServerStateDir and TempDir to verify cleanup
+ CreateSmallTestTree(ServerStateDir);
+ WriteFile(HydrationTemp / "leftover.tmp", CreateSemiRandomBlob(64));
+
+ // Obliterate
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Obliterate();
+ }
+
+ // Backend store directory deleted
+ CHECK_FALSE(std::filesystem::exists(HydrationStore / ModuleId));
+ // ServerStateDir cleaned
+ CHECK(std::filesystem::is_empty(ServerStateDir));
+ // TempDir cleaned
+ CHECK(std::filesystem::is_empty(HydrationTemp));
}
// ---------------------------------------------------------------------------
@@ -750,6 +1415,8 @@ TEST_CASE("hydration.file.concurrent")
std::filesystem::path HydrationStore = TempDir.Path() / "hydration_store";
CreateDirectories(HydrationStore);
+ TestThreading Threading(8);
+
struct ModuleData
{
HydrationConfig Config;
@@ -769,7 +1436,8 @@ TEST_CASE("hydration.file.concurrent")
Modules[I].Config.TempDir = TempPath;
Modules[I].Config.ModuleId = ModuleId;
Modules[I].Config.TargetSpecification = "file://" + HydrationStore.string();
- Modules[I].Files = CreateTestTree(StateDir);
+ Modules[I].Config.Threading = Threading.Options;
+ Modules[I].Files = CreateSmallTestTree(StateDir);
}
// Concurrent dehydrate
@@ -783,7 +1451,7 @@ TEST_CASE("hydration.file.concurrent")
{
Work.ScheduleWork(Pool, [&Config = Modules[I].Config](std::atomic<bool>&) {
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
- Hydrator->Dehydrate();
+ Hydrator->Dehydrate(CbObject());
});
}
Work.Wait();
@@ -818,14 +1486,14 @@ TEST_CASE("hydration.file.concurrent")
// ---------------------------------------------------------------------------
// S3Hydrator tests
//
-// Each test case spawns its own local MinIO instance (self-contained, no external setup needed).
+// Each test case spawns a local MinIO instance (self-contained, no external setup needed).
// The MinIO binary must be present in the same directory as the test executable (copied by xmake).
// ---------------------------------------------------------------------------
TEST_CASE("hydration.s3.dehydrate_hydrate")
{
MinioProcessOptions MinioOpts;
- MinioOpts.Port = 19010;
+ MinioOpts.Port = 19011;
MinioProcess Minio(MinioOpts);
Minio.SpawnMinioServer();
Minio.CreateBucket("zen-hydration-test");
@@ -840,168 +1508,57 @@ TEST_CASE("hydration.s3.dehydrate_hydrate")
CreateDirectories(ServerStateDir);
CreateDirectories(HydrationTemp);
- const std::string ModuleId = "s3test_roundtrip";
- auto TestFiles = CreateTestTree(ServerStateDir);
-
HydrationConfig Config;
- Config.ServerStateDir = ServerStateDir;
- Config.TempDir = HydrationTemp;
- Config.ModuleId = ModuleId;
- Config.TargetSpecification = "s3://zen-hydration-test";
- Config.S3Endpoint = Minio.Endpoint();
- Config.S3PathStyle = true;
-
- // Dehydrate: upload server state to MinIO
+ Config.ServerStateDir = ServerStateDir;
+ Config.TempDir = HydrationTemp;
+ Config.ModuleId = "s3test_roundtrip";
{
- std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
- Hydrator->Dehydrate();
+ std::string ConfigJson =
+ fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Config.Options = std::move(Root).AsObject();
}
- // Wipe server state
- CleanDirectory(ServerStateDir, true);
- CHECK(std::filesystem::is_empty(ServerStateDir));
-
- // Hydrate: download from MinIO back to server state
+ // Hydrate with no prior S3 state (first-boot path). Pre-populate ServerStateDir
+ // with a stale file to confirm the cleanup branch wipes it.
+ WriteFile(ServerStateDir / "stale.bin", CreateSemiRandomBlob(256));
{
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
Hydrator->Hydrate();
}
-
- // Verify restored contents match the original
- VerifyTree(ServerStateDir, TestFiles);
-}
-
-TEST_CASE("hydration.s3.current_state_json_selects_latest_folder")
-{
- // Each Dehydrate() uploads files to a new timestamp-named folder and then overwrites
- // current-state.json to point at that folder. Old folders are NOT deleted.
- // Hydrate() must read current-state.json to determine which folder to restore from.
- //
- // This test verifies that:
- // 1. After two dehydrations, Hydrate() restores from the second snapshot, not the first,
- // confirming that current-state.json was updated between dehydrations.
- // 2. current-state.json is updated to point at the second (latest) folder.
- // 3. Hydrate() restores the v2 snapshot (identified by v2marker.bin), NOT the v1 snapshot.
-
- MinioProcessOptions MinioOpts;
- MinioOpts.Port = 19011;
- MinioProcess Minio(MinioOpts);
- Minio.SpawnMinioServer();
- Minio.CreateBucket("zen-hydration-test");
-
- ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser());
- ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword());
-
- ScopedTemporaryDirectory TempDir;
-
- std::filesystem::path ServerStateDir = TempDir.Path() / "server_state";
- std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp";
- CreateDirectories(ServerStateDir);
- CreateDirectories(HydrationTemp);
-
- const std::string ModuleId = "s3test_folder_select";
-
- HydrationConfig Config;
- Config.ServerStateDir = ServerStateDir;
- Config.TempDir = HydrationTemp;
- Config.ModuleId = ModuleId;
- Config.TargetSpecification = "s3://zen-hydration-test";
- Config.S3Endpoint = Minio.Endpoint();
- Config.S3PathStyle = true;
+ CHECK(std::filesystem::is_empty(ServerStateDir));
// v1: dehydrate without a marker file
- CreateTestTree(ServerStateDir);
+ CreateSmallTestTree(ServerStateDir);
{
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
- Hydrator->Dehydrate();
+ Hydrator->Dehydrate(CbObject());
}
- // ServerStateDir is now empty. Wait so the v2 timestamp folder name is strictly later
- // (timestamp resolution is 1 ms, but macOS scheduler granularity requires a larger margin).
- Sleep(100);
-
// v2: dehydrate WITH a marker file that only v2 has
- CreateTestTree(ServerStateDir);
+ CreateSmallTestTree(ServerStateDir);
WriteFile(ServerStateDir / "v2marker.bin", CreateSemiRandomBlob(64));
{
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
- Hydrator->Dehydrate();
+ Hydrator->Dehydrate(CbObject());
}
- // Hydrate must restore v2 (current-state.json points to the v2 folder)
+ // Hydrate must restore v2 (the latest dehydrated state)
CleanDirectory(ServerStateDir, true);
{
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
Hydrator->Hydrate();
}
- // v2 marker must be present - confirms current-state.json pointed to the v2 folder
+ // v2 marker must be present - confirms the second dehydration overwrote the first
CHECK(std::filesystem::exists(ServerStateDir / "v2marker.bin"));
- // Subdirectory hierarchy must also be intact
CHECK(std::filesystem::exists(ServerStateDir / "subdir" / "file_b.bin"));
CHECK(std::filesystem::exists(ServerStateDir / "subdir" / "nested" / "file_c.bin"));
}
-TEST_CASE("hydration.s3.module_isolation")
-{
- // Two independent modules dehydrate/hydrate without interfering with each other.
- // Uses VerifyTree with per-module byte content to detect cross-module data mixing.
- MinioProcessOptions MinioOpts;
- MinioOpts.Port = 19012;
- MinioProcess Minio(MinioOpts);
- Minio.SpawnMinioServer();
- Minio.CreateBucket("zen-hydration-test");
-
- ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser());
- ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword());
-
- ScopedTemporaryDirectory TempDir;
-
- struct ModuleData
- {
- HydrationConfig Config;
- std::vector<std::pair<std::filesystem::path, IoBuffer>> Files;
- };
-
- std::vector<ModuleData> Modules;
- for (const char* ModuleId : {"s3test_iso_a", "s3test_iso_b"})
- {
- std::filesystem::path StateDir = TempDir.Path() / ModuleId / "state";
- std::filesystem::path TempPath = TempDir.Path() / ModuleId / "temp";
- CreateDirectories(StateDir);
- CreateDirectories(TempPath);
-
- ModuleData Data;
- Data.Config.ServerStateDir = StateDir;
- Data.Config.TempDir = TempPath;
- Data.Config.ModuleId = ModuleId;
- Data.Config.TargetSpecification = "s3://zen-hydration-test";
- Data.Config.S3Endpoint = Minio.Endpoint();
- Data.Config.S3PathStyle = true;
- Data.Files = CreateTestTree(StateDir);
-
- std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Data.Config);
- Hydrator->Dehydrate();
-
- Modules.push_back(std::move(Data));
- }
-
- for (ModuleData& Module : Modules)
- {
- CleanDirectory(Module.Config.ServerStateDir, true);
- std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Module.Config);
- Hydrator->Hydrate();
-
- // Each module's files must be independently restorable with correct byte content.
- // If S3 key prefixes were mixed up, CreateSemiRandomBlob content would differ.
- VerifyTree(Module.Config.ServerStateDir, Module.Files);
- }
-}
-
-// ---------------------------------------------------------------------------
-// S3Hydrator concurrent test
-// ---------------------------------------------------------------------------
-
TEST_CASE("hydration.s3.concurrent")
{
// N modules dehydrate and hydrate concurrently against MinIO.
@@ -1015,7 +1572,10 @@ TEST_CASE("hydration.s3.concurrent")
ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser());
ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword());
- constexpr int kModuleCount = 4;
+ constexpr int kModuleCount = 6;
+ constexpr int kThreadCount = 4;
+
+ TestThreading Threading(kThreadCount);
ScopedTemporaryDirectory TempDir;
@@ -1034,18 +1594,25 @@ TEST_CASE("hydration.s3.concurrent")
CreateDirectories(StateDir);
CreateDirectories(TempPath);
- Modules[I].Config.ServerStateDir = StateDir;
- Modules[I].Config.TempDir = TempPath;
- Modules[I].Config.ModuleId = ModuleId;
- Modules[I].Config.TargetSpecification = "s3://zen-hydration-test";
- Modules[I].Config.S3Endpoint = Minio.Endpoint();
- Modules[I].Config.S3PathStyle = true;
- Modules[I].Files = CreateTestTree(StateDir);
+ Modules[I].Config.ServerStateDir = StateDir;
+ Modules[I].Config.TempDir = TempPath;
+ Modules[I].Config.ModuleId = ModuleId;
+ Modules[I].Config.Threading = Threading.Options;
+ {
+ std::string ConfigJson =
+ fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Modules[I].Config.Options = std::move(Root).AsObject();
+ }
+ Modules[I].Files = CreateTestTree(StateDir);
}
// Concurrent dehydrate
{
- WorkerThreadPool Pool(kModuleCount, "hydration_s3_dehy");
+ WorkerThreadPool Pool(kThreadCount, "hydration_s3_dehy");
std::atomic<bool> AbortFlag{false};
std::atomic<bool> PauseFlag{false};
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
@@ -1054,7 +1621,7 @@ TEST_CASE("hydration.s3.concurrent")
{
Work.ScheduleWork(Pool, [&Config = Modules[I].Config](std::atomic<bool>&) {
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
- Hydrator->Dehydrate();
+ Hydrator->Dehydrate(CbObject());
});
}
Work.Wait();
@@ -1063,7 +1630,7 @@ TEST_CASE("hydration.s3.concurrent")
// Concurrent hydrate
{
- WorkerThreadPool Pool(kModuleCount, "hydration_s3_hy");
+ WorkerThreadPool Pool(kThreadCount, "hydration_s3_hy");
std::atomic<bool> AbortFlag{false};
std::atomic<bool> PauseFlag{false};
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::EnableBacklog);
@@ -1087,17 +1654,82 @@ TEST_CASE("hydration.s3.concurrent")
}
}
-// ---------------------------------------------------------------------------
-// S3Hydrator: no prior state (first-boot path)
-// ---------------------------------------------------------------------------
+TEST_CASE("hydration.s3.obliterate")
+{
+ MinioProcessOptions MinioOpts;
+ MinioOpts.Port = 19019;
+ MinioProcess Minio(MinioOpts);
+ Minio.SpawnMinioServer();
+ Minio.CreateBucket("zen-hydration-test");
+
+ ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser());
+ ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword());
+
+ ScopedTemporaryDirectory TempDir;
+
+ std::filesystem::path ServerStateDir = TempDir.Path() / "server_state";
+ std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp";
+ CreateDirectories(ServerStateDir);
+ CreateDirectories(HydrationTemp);
+
+ const std::string ModuleId = "s3test_obliterate";
+
+ HydrationConfig Config;
+ Config.ServerStateDir = ServerStateDir;
+ Config.TempDir = HydrationTemp;
+ Config.ModuleId = ModuleId;
+ {
+ std::string ConfigJson =
+ fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Config.Options = std::move(Root).AsObject();
+ }
+
+ // Dehydrate to populate backend
+ CreateSmallTestTree(ServerStateDir);
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Dehydrate(CbObject());
+ }
+
+ auto ListModuleObjects = [&]() {
+ S3ClientOptions Opts;
+ Opts.BucketName = "zen-hydration-test";
+ Opts.Endpoint = Minio.Endpoint();
+ Opts.PathStyle = true;
+ Opts.Credentials.AccessKeyId = Minio.RootUser();
+ Opts.Credentials.SecretAccessKey = Minio.RootPassword();
+ S3Client Client(Opts);
+ return Client.ListObjects(ModuleId + "/");
+ };
+
+ // Verify objects exist in S3
+ CHECK(!ListModuleObjects().Objects.empty());
+
+ // Re-populate ServerStateDir and TempDir for cleanup verification
+ CreateSmallTestTree(ServerStateDir);
+ WriteFile(HydrationTemp / "leftover.tmp", CreateSemiRandomBlob(64));
+
+ // Obliterate
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Obliterate();
+ }
+
+ // Verify S3 objects deleted
+ CHECK(ListModuleObjects().Objects.empty());
+ // Local directories cleaned
+ CHECK(std::filesystem::is_empty(ServerStateDir));
+ CHECK(std::filesystem::is_empty(HydrationTemp));
+}
-TEST_CASE("hydration.s3.no_prior_state")
+TEST_CASE("hydration.s3.config_overrides")
{
- // Hydrate() against an empty bucket (first-boot scenario) must leave ServerStateDir empty.
- // The "No state found in S3" path goes through the error-cleanup branch, which wipes
- // ServerStateDir to ensure no partial or stale content is left for the server to start on.
MinioProcessOptions MinioOpts;
- MinioOpts.Port = 19014;
+ MinioOpts.Port = 19015;
MinioProcess Minio(MinioOpts);
Minio.SpawnMinioServer();
Minio.CreateBucket("zen-hydration-test");
@@ -1112,36 +1744,244 @@ TEST_CASE("hydration.s3.no_prior_state")
CreateDirectories(ServerStateDir);
CreateDirectories(HydrationTemp);
- // Pre-populate ServerStateDir to confirm the wipe actually runs.
- WriteFile(ServerStateDir / "stale.bin", CreateSemiRandomBlob(256));
+ // Path prefix: "s3://bucket/some/prefix" stores objects under
+ // "some/prefix/<ModuleId>/..." rather than directly under "<ModuleId>/...".
+ {
+ auto TestFiles = CreateSmallTestTree(ServerStateDir);
+
+ HydrationConfig Config;
+ Config.ServerStateDir = ServerStateDir;
+ Config.TempDir = HydrationTemp;
+ Config.ModuleId = "s3test_prefix";
+ {
+ std::string ConfigJson = fmt::format(
+ R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test/team/project","endpoint":"{}","path-style":true}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Config.Options = std::move(Root).AsObject();
+ }
+
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Dehydrate(CbObject());
+ }
+
+ CleanDirectory(ServerStateDir, true);
+
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Hydrate();
+ }
+
+ VerifyTree(ServerStateDir, TestFiles);
+ }
+
+ // Region override: 'region' in Options["settings"] takes precedence over AWS_DEFAULT_REGION.
+ // AWS_DEFAULT_REGION is set to a bogus value; hydration must succeed using the region from Options.
+ {
+ CleanDirectory(ServerStateDir, true);
+ auto TestFiles = CreateSmallTestTree(ServerStateDir);
+
+ ScopedEnvVar EnvRegion("AWS_DEFAULT_REGION", "wrong-region");
+
+ HydrationConfig Config;
+ Config.ServerStateDir = ServerStateDir;
+ Config.TempDir = HydrationTemp;
+ Config.ModuleId = "s3test_region_override";
+ {
+ std::string ConfigJson = fmt::format(
+ R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true,"region":"us-east-1"}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Config.Options = std::move(Root).AsObject();
+ }
+
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Dehydrate(CbObject());
+ }
+
+ CleanDirectory(ServerStateDir, true);
+
+ {
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Hydrate();
+ }
+
+ VerifyTree(ServerStateDir, TestFiles);
+ }
+}
+
+TEST_CASE("hydration.s3.dehydrate_hydrate.performance" * doctest::skip())
+{
+ MinioProcessOptions MinioOpts;
+ MinioOpts.Port = 19010;
+ MinioProcess Minio(MinioOpts);
+ Minio.SpawnMinioServer();
+ Minio.CreateBucket("zen-hydration-test");
+
+ ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser());
+ ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword());
+
+ ScopedTemporaryDirectory TempDir;
+
+ std::filesystem::path ServerStateDir = TempDir.Path() / "server_state";
+ std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp";
+ CreateDirectories(ServerStateDir);
+ CreateDirectories(HydrationTemp);
+
+ const std::string ModuleId = "s3test_performance";
+ CopyTree("E:\\Dev\\hub\\brainrot\\20260402-225355-508", ServerStateDir, {.EnableClone = true});
+ // auto TestFiles = CreateTestTree(ServerStateDir);
+
+ TestThreading Threading(4);
+
+ HydrationConfig Config;
+ Config.ServerStateDir = ServerStateDir;
+ Config.TempDir = HydrationTemp;
+ Config.ModuleId = ModuleId;
+ Config.Threading = Threading.Options;
+ std::string ConfigJson =
+ fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Config.Options = std::move(Root).AsObject();
+
+ // Dehydrate: upload server state to MinIO
+ {
+ ZEN_INFO("============== DEHYDRATE ==============");
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Dehydrate(CbObject());
+ }
+
+ for (size_t I = 0; I < 1; I++)
+ {
+ // Wipe server state
+ CleanDirectory(ServerStateDir, true);
+ CHECK(std::filesystem::is_empty(ServerStateDir));
+
+ // Hydrate: download from MinIO back to server state
+ {
+ ZEN_INFO("=============== HYDRATE ===============");
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Hydrate();
+ }
+ }
+}
+
+//#define REAL_DATA_PATH "E:\\Dev\\hub\\zenddc\\Zen"
+//#define REAL_DATA_PATH "E:\\Dev\\hub\\brainrot\\20260402-225355-508"
+
+TEST_CASE("hydration.file.incremental")
+{
+ std::filesystem::path TmpPath;
+# ifdef REAL_DATA_PATH
+ TmpPath = std::filesystem::path(REAL_DATA_PATH).parent_path() / "hub";
+# endif
+ ScopedTemporaryDirectory TempDir(TmpPath);
+
+ std::filesystem::path ServerStateDir = TempDir.Path() / "server_state";
+ std::filesystem::path HydrationStore = TempDir.Path() / "hydration_store";
+ std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp";
+ CreateDirectories(ServerStateDir);
+ CreateDirectories(HydrationStore);
+ CreateDirectories(HydrationTemp);
+
+ const std::string ModuleId = "testmodule";
+ // auto TestFiles = CreateTestTree(ServerStateDir);
+
+ TestThreading Threading(4);
HydrationConfig Config;
Config.ServerStateDir = ServerStateDir;
Config.TempDir = HydrationTemp;
- Config.ModuleId = "s3test_no_prior";
- Config.TargetSpecification = "s3://zen-hydration-test";
- Config.S3Endpoint = Minio.Endpoint();
- Config.S3PathStyle = true;
+ Config.ModuleId = ModuleId;
+ Config.TargetSpecification = "file://" + HydrationStore.string();
+ Config.Threading = Threading.Options;
- std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
- Hydrator->Hydrate();
+ std::unique_ptr<StorageBase> Storage = std::make_unique<FileStorage>();
+ std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<IncrementalHydrator>(std::move(Storage));
- // ServerStateDir must be empty: the error path wipes it to prevent a server start
- // against stale or partially-installed content.
+ // Hydrate with no prior state
+ CbObject HydrationState;
+ {
+ Hydrator->Configure(Config);
+ HydrationState = Hydrator->Hydrate();
+ CHECK_FALSE(HydrationState);
+ }
+
+# ifdef REAL_DATA_PATH
+ ZEN_INFO("Writing state data...");
+ CopyTree(REAL_DATA_PATH, ServerStateDir, {.EnableClone = true});
+ ZEN_INFO("Writing state data complete");
+# else
+ // Create test files and dehydrate
+ auto TestFiles = CreateTestTree(ServerStateDir);
+# endif
+ {
+ Hydrator->Configure(Config);
+ Hydrator->Dehydrate(HydrationState);
+ }
+ CHECK(std::filesystem::is_empty(ServerStateDir));
+
+ // Hydrate: restore from S3
+ {
+ Hydrator->Configure(Config);
+ HydrationState = Hydrator->Hydrate();
+ }
+# ifndef REAL_DATA_PATH
+ VerifyTree(ServerStateDir, TestFiles);
+# endif
+ // Dehydrate again with cached state (should skip re-uploading unchanged files)
+ {
+ Hydrator->Configure(Config);
+ Hydrator->Dehydrate(HydrationState);
+ }
CHECK(std::filesystem::is_empty(ServerStateDir));
+
+ // Hydrate one more time to confirm second dehydrate produced valid state
+ {
+ Hydrator->Configure(Config);
+ HydrationState = Hydrator->Hydrate();
+ }
+
+ // Replace files and dehydrate
+ TestFiles = CreateTestTree(ServerStateDir);
+ {
+ Hydrator->Configure(Config);
+ Hydrator->Dehydrate(HydrationState);
+ }
+
+ // Hydrate one more time to confirm second dehydrate produced valid state
+ {
+ Hydrator->Configure(Config);
+ HydrationState = Hydrator->Hydrate();
+ }
+# ifndef REAL_DATA_PATH
+ VerifyTree(ServerStateDir, TestFiles);
+# endif // 0
+
+ // Dehydrate, nothing touched - no hashing, no upload
+ {
+ Hydrator->Configure(Config);
+ Hydrator->Dehydrate(HydrationState);
+ }
}
// ---------------------------------------------------------------------------
-// S3Hydrator: bucket path prefix in TargetSpecification
+// S3Storage test
// ---------------------------------------------------------------------------
-TEST_CASE("hydration.s3.path_prefix")
+TEST_CASE("hydration.s3.incremental")
{
- // TargetSpecification of the form "s3://bucket/some/prefix" stores objects under
- // "some/prefix/<ModuleId>/..." rather than directly under "<ModuleId>/...".
- // Tests the second branch of the m_KeyPrefix calculation in S3Hydrator::Configure().
MinioProcessOptions MinioOpts;
- MinioOpts.Port = 19015;
+ MinioOpts.Port = 19017;
MinioProcess Minio(MinioOpts);
Minio.SpawnMinioServer();
Minio.CreateBucket("zen-hydration-test");
@@ -1149,36 +1989,132 @@ TEST_CASE("hydration.s3.path_prefix")
ScopedEnvVar EnvAccessKey("AWS_ACCESS_KEY_ID", Minio.RootUser());
ScopedEnvVar EnvSecretKey("AWS_SECRET_ACCESS_KEY", Minio.RootPassword());
- ScopedTemporaryDirectory TempDir;
+ std::filesystem::path TmpPath;
+# ifdef REAL_DATA_PATH
+ TmpPath = std::filesystem::path(REAL_DATA_PATH).parent_path() / "hub";
+# endif
+ ScopedTemporaryDirectory TempDir(TmpPath);
std::filesystem::path ServerStateDir = TempDir.Path() / "server_state";
std::filesystem::path HydrationTemp = TempDir.Path() / "hydration_temp";
CreateDirectories(ServerStateDir);
CreateDirectories(HydrationTemp);
- std::vector<std::pair<std::filesystem::path, IoBuffer>> TestFiles = CreateTestTree(ServerStateDir);
+ const std::string ModuleId = "s3test_incremental";
+
+ TestThreading Threading(8);
HydrationConfig Config;
- Config.ServerStateDir = ServerStateDir;
- Config.TempDir = HydrationTemp;
- Config.ModuleId = "s3test_prefix";
- Config.TargetSpecification = "s3://zen-hydration-test/team/project";
- Config.S3Endpoint = Minio.Endpoint();
- Config.S3PathStyle = true;
+ Config.ServerStateDir = ServerStateDir;
+ Config.TempDir = HydrationTemp;
+ Config.ModuleId = ModuleId;
+ Config.Threading = Threading.Options;
+ {
+ std::string ConfigJson =
+ fmt::format(R"({{"type":"s3","settings":{{"uri":"s3://zen-hydration-test","endpoint":"{}","path-style":true}}}})",
+ Minio.Endpoint());
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(ConfigJson, ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Config.Options = std::move(Root).AsObject();
+ }
+
+ std::unique_ptr<StorageBase> Storage = std::make_unique<S3Storage>();
+ std::unique_ptr<HydrationStrategyBase> Hydrator = std::make_unique<IncrementalHydrator>(std::move(Storage));
+ // Hydrate with no prior state
+ CbObject HydrationState;
{
- std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
- Hydrator->Dehydrate();
+ Hydrator->Configure(Config);
+ HydrationState = Hydrator->Hydrate();
+ CHECK_FALSE(HydrationState);
}
- CleanDirectory(ServerStateDir, true);
+# ifdef REAL_DATA_PATH
+ ZEN_INFO("Writing state data...");
+ CopyTree(REAL_DATA_PATH, ServerStateDir, {.EnableClone = true});
+ ZEN_INFO("Writing state data complete");
+# else
+ // Create test files and dehydrate
+ auto TestFiles = CreateTestTree(ServerStateDir);
+# endif
+ {
+ Hydrator->Configure(Config);
+ Hydrator->Dehydrate(HydrationState);
+ }
+ CHECK(std::filesystem::is_empty(ServerStateDir));
+ // Hydrate: restore from S3
{
- std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
- Hydrator->Hydrate();
+ Hydrator->Configure(Config);
+ HydrationState = Hydrator->Hydrate();
+ }
+# ifndef REAL_DATA_PATH
+ VerifyTree(ServerStateDir, TestFiles);
+# endif
+ // Dehydrate again with cached state (should skip re-uploading unchanged files)
+ {
+ Hydrator->Configure(Config);
+ Hydrator->Dehydrate(HydrationState);
+ }
+ CHECK(std::filesystem::is_empty(ServerStateDir));
+
+ // Hydrate one more time to confirm second dehydrate produced valid state
+ {
+ Hydrator->Configure(Config);
+ HydrationState = Hydrator->Hydrate();
+ }
+
+ // Replace files and dehydrate
+ TestFiles = CreateTestTree(ServerStateDir);
+ {
+ Hydrator->Configure(Config);
+ Hydrator->Dehydrate(HydrationState);
}
+ // Hydrate one more time to confirm second dehydrate produced valid state
+ {
+ Hydrator->Configure(Config);
+ HydrationState = Hydrator->Hydrate();
+ }
+
+# ifndef REAL_DATA_PATH
VerifyTree(ServerStateDir, TestFiles);
+# endif // 0
+
+ // Dehydrate, nothing touched - no hashing, no upload
+ {
+ Hydrator->Configure(Config);
+ Hydrator->Dehydrate(HydrationState);
+ }
+}
+
+TEST_CASE("hydration.create_hydrator_rejects_invalid_config")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ HydrationConfig Config;
+ Config.ServerStateDir = TempDir.Path() / "state";
+ Config.TempDir = TempDir.Path() / "temp";
+ Config.ModuleId = "invalid_test";
+
+ // Unknown TargetSpecification prefix
+ Config.TargetSpecification = "ftp://somewhere";
+ CHECK_THROWS(CreateHydrator(Config));
+
+ // Unknown Options type
+ Config.TargetSpecification.clear();
+ {
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(R"({"type":"dynamodb"})", ParseError);
+ ZEN_ASSERT(ParseError.empty() && Root.IsObject());
+ Config.Options = std::move(Root).AsObject();
+ }
+ CHECK_THROWS(CreateHydrator(Config));
+
+ // Empty Options (no type field)
+ Config.Options = CbObject();
+ CHECK_THROWS(CreateHydrator(Config));
}
TEST_SUITE_END();
diff --git a/src/zenserver/hub/hydration.h b/src/zenserver/hub/hydration.h
index d29ffe5c0..fc2f309b2 100644
--- a/src/zenserver/hub/hydration.h
+++ b/src/zenserver/hub/hydration.h
@@ -2,10 +2,15 @@
#pragma once
+#include <zencore/compactbinary.h>
+
#include <filesystem>
+#include <optional>
namespace zen {
+class WorkerThreadPool;
+
struct HydrationConfig
{
// Location of server state to hydrate/dehydrate
@@ -16,12 +21,18 @@ struct HydrationConfig
std::string ModuleId;
// Back-end specific target specification (e.g. S3 bucket, file path, etc)
std::string TargetSpecification;
+ // Full config object when using --hub-hydration-target-config (mutually exclusive with TargetSpecification)
+ CbObject Options;
+
+ struct ThreadingOptions
+ {
+ WorkerThreadPool* WorkerPool;
+ std::atomic<bool>* AbortFlag;
+ std::atomic<bool>* PauseFlag;
+ };
- // Optional S3 endpoint override (e.g. "http://localhost:9000" for MinIO).
- std::string S3Endpoint;
- // Use path-style S3 URLs (endpoint/bucket/key) instead of virtual-hosted-style
- // (bucket.endpoint/key). Required for MinIO and other non-AWS endpoints.
- bool S3PathStyle = false;
+ // External threading for parallel I/O and hashing. If not set, work runs inline on the caller's thread.
+ std::optional<ThreadingOptions> Threading;
};
/**
@@ -36,11 +47,22 @@ struct HydrationStrategyBase
{
virtual ~HydrationStrategyBase() = default;
- virtual void Dehydrate() = 0;
- virtual void Hydrate() = 0;
+ // Set up the hydration target from Config. Must be called before Hydrate/Dehydrate.
virtual void Configure(const HydrationConfig& Config) = 0;
+
+ // Upload server state to the configured target. ServerStateDir is wiped on success.
+ // On failure, ServerStateDir is left intact.
+ virtual void Dehydrate(const CbObject& CachedState) = 0;
+
+ // Download state from the configured target into ServerStateDir. Returns cached state for the next Dehydrate.
+ // On failure, ServerStateDir is wiped and an empty CbObject is returned.
+ virtual CbObject Hydrate() = 0;
+
+ // Delete all stored data for this module from the configured backend, then clean ServerStateDir and TempDir.
+ virtual void Obliterate() = 0;
};
+// Create a configured hydrator based on Config. Ready to call Hydrate/Dehydrate immediately.
std::unique_ptr<HydrationStrategyBase> CreateHydrator(const HydrationConfig& Config);
#if ZEN_WITH_TESTS
diff --git a/src/zenserver/hub/storageserverinstance.cpp b/src/zenserver/hub/storageserverinstance.cpp
index 6b139dbf1..af2c19113 100644
--- a/src/zenserver/hub/storageserverinstance.cpp
+++ b/src/zenserver/hub/storageserverinstance.cpp
@@ -16,8 +16,6 @@ StorageServerInstance::StorageServerInstance(ZenServerEnvironment& RunEnvironmen
, m_ModuleId(ModuleId)
, m_ServerInstance(RunEnvironment, ZenServerInstance::ServerMode::kStorageServer)
{
- m_BaseDir = RunEnvironment.CreateChildDir(ModuleId);
- m_TempDir = Config.HydrationTempPath / ModuleId;
}
StorageServerInstance::~StorageServerInstance()
@@ -31,7 +29,7 @@ StorageServerInstance::SpawnServerProcess()
m_ServerInstance.ResetDeadProcess();
m_ServerInstance.SetServerExecutablePath(GetRunningExecutablePath());
- m_ServerInstance.SetDataDir(m_BaseDir);
+ m_ServerInstance.SetDataDir(m_Config.StateDir);
#if ZEN_PLATFORM_WINDOWS
m_ServerInstance.SetJobObject(m_JobObject);
#endif
@@ -50,6 +48,36 @@ StorageServerInstance::SpawnServerProcess()
{
AdditionalOptions << " --config=\"" << MakeSafeAbsolutePath(m_Config.ConfigPath).string() << "\"";
}
+ if (!m_Config.Malloc.empty())
+ {
+ AdditionalOptions << " --malloc=" << m_Config.Malloc;
+ }
+ if (!m_Config.Trace.empty())
+ {
+ AdditionalOptions << " --trace=" << m_Config.Trace;
+ }
+ if (!m_Config.TraceHost.empty())
+ {
+ AdditionalOptions << " --tracehost=" << m_Config.TraceHost;
+ }
+ if (!m_Config.TraceFile.empty())
+ {
+ constexpr std::string_view ModuleIdPattern = "{moduleid}";
+ constexpr std::string_view PortPattern = "{port}";
+
+ std::string ResolvedTraceFile = m_Config.TraceFile;
+ for (size_t Pos = ResolvedTraceFile.find(ModuleIdPattern); Pos != std::string::npos;
+ Pos = ResolvedTraceFile.find(ModuleIdPattern, Pos))
+ {
+ ResolvedTraceFile.replace(Pos, ModuleIdPattern.length(), m_ModuleId);
+ }
+ std::string PortStr = fmt::format("{}", m_Config.BasePort);
+ for (size_t Pos = ResolvedTraceFile.find(PortPattern); Pos != std::string::npos; Pos = ResolvedTraceFile.find(PortPattern, Pos))
+ {
+ ResolvedTraceFile.replace(Pos, PortPattern.length(), PortStr);
+ }
+ AdditionalOptions << " --tracefile=\"" << ResolvedTraceFile << "\"";
+ }
m_ServerInstance.SpawnServerAndWaitUntilReady(m_Config.BasePort, AdditionalOptions.ToView());
ZEN_DEBUG("Storage server instance for module '{}' started, listening on port {}", m_ModuleId, m_Config.BasePort);
@@ -57,16 +85,15 @@ StorageServerInstance::SpawnServerProcess()
m_ServerInstance.EnableShutdownOnDestroy();
}
-void
-StorageServerInstance::GetProcessMetrics(ProcessMetrics& OutMetrics) const
+ProcessMetrics
+StorageServerInstance::GetProcessMetrics() const
{
- OutMetrics.MemoryBytes = m_MemoryBytes.load();
- OutMetrics.KernelTimeMs = m_KernelTimeMs.load();
- OutMetrics.UserTimeMs = m_UserTimeMs.load();
- OutMetrics.WorkingSetSize = m_WorkingSetSize.load();
- OutMetrics.PeakWorkingSetSize = m_PeakWorkingSetSize.load();
- OutMetrics.PagefileUsage = m_PagefileUsage.load();
- OutMetrics.PeakPagefileUsage = m_PeakPagefileUsage.load();
+ ProcessMetrics Metrics;
+ if (m_ServerInstance.IsRunning())
+ {
+ zen::GetProcessMetrics(m_ServerInstance.GetProcessHandle(), Metrics);
+ }
+ return Metrics;
}
void
@@ -78,7 +105,7 @@ StorageServerInstance::ProvisionLocked()
return;
}
- ZEN_INFO("Provisioning storage server instance for module '{}', at '{}'", m_ModuleId, m_BaseDir);
+ ZEN_INFO("Provisioning storage server instance for module '{}', at '{}'", m_ModuleId, m_Config.StateDir);
try
{
Hydrate();
@@ -88,7 +115,7 @@ StorageServerInstance::ProvisionLocked()
{
ZEN_WARN("Failed spawning server instance for module '{}', at '{}' during provisioning. Reason: {}",
m_ModuleId,
- m_BaseDir,
+ m_Config.StateDir,
Ex.what());
throw;
}
@@ -118,6 +145,22 @@ StorageServerInstance::DeprovisionLocked()
}
void
+StorageServerInstance::ObliterateLocked()
+{
+ if (m_ServerInstance.IsRunning())
+ {
+ // m_ServerInstance.Shutdown() never throws.
+ m_ServerInstance.Shutdown();
+ }
+
+ std::atomic<bool> AbortFlag{false};
+ std::atomic<bool> PauseFlag{false};
+ HydrationConfig Config = MakeHydrationConfig(AbortFlag, PauseFlag);
+ std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Obliterate();
+}
+
+void
StorageServerInstance::HibernateLocked()
{
// Signal server to shut down, but keep data around for later wake
@@ -147,7 +190,10 @@ StorageServerInstance::WakeLocked()
}
catch (const std::exception& Ex)
{
- ZEN_WARN("Failed spawning server instance for module '{}', at '{}' during waking. Reason: {}", m_ModuleId, m_BaseDir, Ex.what());
+ ZEN_WARN("Failed spawning server instance for module '{}', at '{}' during waking. Reason: {}",
+ m_ModuleId,
+ m_Config.StateDir,
+ Ex.what());
throw;
}
}
@@ -155,27 +201,38 @@ StorageServerInstance::WakeLocked()
void
StorageServerInstance::Hydrate()
{
- HydrationConfig Config{.ServerStateDir = m_BaseDir,
- .TempDir = m_TempDir,
- .ModuleId = m_ModuleId,
- .TargetSpecification = m_Config.HydrationTargetSpecification};
-
+ std::atomic<bool> AbortFlag{false};
+ std::atomic<bool> PauseFlag{false};
+ HydrationConfig Config = MakeHydrationConfig(AbortFlag, PauseFlag);
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
-
- Hydrator->Hydrate();
+ m_HydrationState = Hydrator->Hydrate();
}
void
StorageServerInstance::Dehydrate()
{
- HydrationConfig Config{.ServerStateDir = m_BaseDir,
- .TempDir = m_TempDir,
- .ModuleId = m_ModuleId,
- .TargetSpecification = m_Config.HydrationTargetSpecification};
-
+ std::atomic<bool> AbortFlag{false};
+ std::atomic<bool> PauseFlag{false};
+ HydrationConfig Config = MakeHydrationConfig(AbortFlag, PauseFlag);
std::unique_ptr<HydrationStrategyBase> Hydrator = CreateHydrator(Config);
+ Hydrator->Dehydrate(m_HydrationState);
+}
+
+HydrationConfig
+StorageServerInstance::MakeHydrationConfig(std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag)
+{
+ HydrationConfig Config{.ServerStateDir = m_Config.StateDir,
+ .TempDir = m_Config.TempDir,
+ .ModuleId = m_ModuleId,
+ .TargetSpecification = m_Config.HydrationTargetSpecification,
+ .Options = m_Config.HydrationOptions};
+ if (m_Config.OptionalWorkerPool)
+ {
+ Config.Threading.emplace(
+ HydrationConfig::ThreadingOptions{.WorkerPool = m_Config.OptionalWorkerPool, .AbortFlag = &AbortFlag, .PauseFlag = &PauseFlag});
+ }
- Hydrator->Dehydrate();
+ return Config;
}
StorageServerInstance::SharedLockedPtr::SharedLockedPtr() : m_Lock(nullptr), m_Instance(nullptr)
@@ -249,25 +306,6 @@ StorageServerInstance::SharedLockedPtr::IsRunning() const
return m_Instance->m_ServerInstance.IsRunning();
}
-void
-StorageServerInstance::UpdateMetricsLocked()
-{
- if (m_ServerInstance.IsRunning())
- {
- ProcessMetrics Metrics;
- zen::GetProcessMetrics(m_ServerInstance.GetProcessHandle(), Metrics);
-
- m_MemoryBytes.store(Metrics.MemoryBytes);
- m_KernelTimeMs.store(Metrics.KernelTimeMs);
- m_UserTimeMs.store(Metrics.UserTimeMs);
- m_WorkingSetSize.store(Metrics.WorkingSetSize);
- m_PeakWorkingSetSize.store(Metrics.PeakWorkingSetSize);
- m_PagefileUsage.store(Metrics.PagefileUsage);
- m_PeakPagefileUsage.store(Metrics.PeakPagefileUsage);
- }
- // TODO: Resource metrics...
-}
-
#if ZEN_WITH_TESTS
void
StorageServerInstance::SharedLockedPtr::TerminateForTesting() const
@@ -363,6 +401,13 @@ StorageServerInstance::ExclusiveLockedPtr::Deprovision()
}
void
+StorageServerInstance::ExclusiveLockedPtr::Obliterate()
+{
+ ZEN_ASSERT(m_Instance != nullptr);
+ m_Instance->ObliterateLocked();
+}
+
+void
StorageServerInstance::ExclusiveLockedPtr::Hibernate()
{
ZEN_ASSERT(m_Instance != nullptr);
diff --git a/src/zenserver/hub/storageserverinstance.h b/src/zenserver/hub/storageserverinstance.h
index 94c47630c..80f8a5016 100644
--- a/src/zenserver/hub/storageserverinstance.h
+++ b/src/zenserver/hub/storageserverinstance.h
@@ -2,8 +2,9 @@
#pragma once
-#include "resourcemetrics.h"
+#include "hydration.h"
+#include <zencore/compactbinary.h>
#include <zenutil/zenserverprocess.h>
#include <atomic>
@@ -11,6 +12,8 @@
namespace zen {
+class WorkerThreadPool;
+
/**
* Storage Server Instance
*
@@ -24,21 +27,27 @@ public:
struct Configuration
{
uint16_t BasePort;
- std::filesystem::path HydrationTempPath;
+ std::filesystem::path StateDir;
+ std::filesystem::path TempDir;
std::string HydrationTargetSpecification;
+ CbObject HydrationOptions;
uint32_t HttpThreadCount = 0; // Automatic
int CoreLimit = 0; // Automatic
std::filesystem::path ConfigPath;
+ std::string Malloc;
+ std::string Trace;
+ std::string TraceHost;
+ std::string TraceFile;
+
+ WorkerThreadPool* OptionalWorkerPool = nullptr;
};
StorageServerInstance(ZenServerEnvironment& RunEnvironment, const Configuration& Config, std::string_view ModuleId);
~StorageServerInstance();
- const ResourceMetrics& GetResourceMetrics() const { return m_ResourceMetrics; }
-
inline std::string_view GetModuleId() const { return m_ModuleId; }
inline uint16_t GetBasePort() const { return m_Config.BasePort; }
- void GetProcessMetrics(ProcessMetrics& OutMetrics) const;
+ ProcessMetrics GetProcessMetrics() const;
#if ZEN_PLATFORM_WINDOWS
void SetJobObject(JobObject* InJobObject) { m_JobObject = InJobObject; }
@@ -68,15 +77,10 @@ public:
}
bool IsRunning() const;
- const ResourceMetrics& GetResourceMetrics() const
+ ProcessMetrics GetProcessMetrics() const
{
ZEN_ASSERT(m_Instance);
- return m_Instance->m_ResourceMetrics;
- }
- void UpdateMetrics()
- {
- ZEN_ASSERT(m_Instance);
- return m_Instance->UpdateMetricsLocked();
+ return m_Instance->GetProcessMetrics();
}
#if ZEN_WITH_TESTS
@@ -114,14 +118,9 @@ public:
}
bool IsRunning() const;
- const ResourceMetrics& GetResourceMetrics() const
- {
- ZEN_ASSERT(m_Instance);
- return m_Instance->m_ResourceMetrics;
- }
-
void Provision();
void Deprovision();
+ void Obliterate();
void Hibernate();
void Wake();
@@ -135,29 +134,17 @@ public:
private:
void ProvisionLocked();
void DeprovisionLocked();
+ void ObliterateLocked();
void HibernateLocked();
void WakeLocked();
- void UpdateMetricsLocked();
-
mutable RwLock m_Lock;
const Configuration m_Config;
std::string m_ModuleId;
ZenServerInstance m_ServerInstance;
- std::filesystem::path m_BaseDir;
-
- std::filesystem::path m_TempDir;
- ResourceMetrics m_ResourceMetrics;
-
- std::atomic<uint64_t> m_MemoryBytes = 0;
- std::atomic<uint64_t> m_KernelTimeMs = 0;
- std::atomic<uint64_t> m_UserTimeMs = 0;
- std::atomic<uint64_t> m_WorkingSetSize = 0;
- std::atomic<uint64_t> m_PeakWorkingSetSize = 0;
- std::atomic<uint64_t> m_PagefileUsage = 0;
- std::atomic<uint64_t> m_PeakPagefileUsage = 0;
+ CbObject m_HydrationState;
#if ZEN_PLATFORM_WINDOWS
JobObject* m_JobObject = nullptr;
@@ -165,8 +152,9 @@ private:
void SpawnServerProcess();
- void Hydrate();
- void Dehydrate();
+ void Hydrate();
+ void Dehydrate();
+ HydrationConfig MakeHydrationConfig(std::atomic<bool>& AbortFlag, std::atomic<bool>& PauseFlag);
friend class SharedLockedPtr;
friend class ExclusiveLockedPtr;
diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp
index 314031246..1390d112e 100644
--- a/src/zenserver/hub/zenhubserver.cpp
+++ b/src/zenserver/hub/zenhubserver.cpp
@@ -2,21 +2,29 @@
#include "zenhubserver.h"
+#include "config/luaconfig.h"
#include "frontend/frontend.h"
#include "httphubservice.h"
+#include "httpproxyhandler.h"
#include "hub.h"
+#include <zencore/compactbinary.h>
#include <zencore/config.h>
+#include <zencore/except.h>
+#include <zencore/except_fmt.h>
+#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
+#include <zencore/intmath.h>
#include <zencore/memory/llm.h>
#include <zencore/memory/memorytrace.h>
#include <zencore/memory/tagtrace.h>
#include <zencore/scopeguard.h>
#include <zencore/sentryintegration.h>
+#include <zencore/system.h>
+#include <zencore/thread.h>
#include <zencore/windows.h>
#include <zenhttp/httpapiservice.h>
#include <zenutil/service.h>
-#include <zenutil/workerpools.h>
ZEN_THIRD_PARTY_INCLUDES_START
#include <cxxopts.hpp>
@@ -53,12 +61,19 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options)
Options.add_option("hub",
"",
"instance-id",
- "Instance ID for use in notifications",
+ "Instance ID for use in notifications (deprecated, use --upstream-notification-instance-id)",
cxxopts::value<std::string>(m_ServerOptions.InstanceId)->default_value(""),
"");
Options.add_option("hub",
"",
+ "upstream-notification-instance-id",
+ "Instance ID for use in notifications",
+ cxxopts::value<std::string>(m_ServerOptions.InstanceId),
+ "");
+
+ Options.add_option("hub",
+ "",
"consul-endpoint",
"Consul endpoint URL for service registration (empty = disabled)",
cxxopts::value<std::string>(m_ServerOptions.ConsulEndpoint)->default_value(""),
@@ -88,13 +103,27 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options)
Options.add_option("hub",
"",
+ "consul-register-hub",
+ "Register the hub parent service with Consul (instance registration is unaffected)",
+ cxxopts::value<bool>(m_ServerOptions.ConsulRegisterHub)->default_value("true"),
+ "");
+
+ Options.add_option("hub",
+ "",
"hub-base-port-number",
- "Base port number for provisioned instances",
+ "Base port number for provisioned instances (deprecated, use --hub-instance-base-port-number)",
cxxopts::value<uint16_t>(m_ServerOptions.HubBasePortNumber)->default_value("21000"),
"");
Options.add_option("hub",
"",
+ "hub-instance-base-port-number",
+ "Base port number for provisioned instances",
+ cxxopts::value<uint16_t>(m_ServerOptions.HubBasePortNumber),
+ "");
+
+ Options.add_option("hub",
+ "",
"hub-instance-limit",
"Maximum number of provisioned instances for this hub",
cxxopts::value<int>(m_ServerOptions.HubInstanceLimit)->default_value("1000"),
@@ -113,6 +142,34 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options)
Options.add_option("hub",
"",
+ "hub-instance-malloc",
+ "Select memory allocator for provisioned instances (ansi|stomp|rpmalloc|mimalloc)",
+ cxxopts::value<std::string>(m_ServerOptions.HubInstanceMalloc)->default_value(""),
+ "<allocator>");
+
+ Options.add_option("hub",
+ "",
+ "hub-instance-trace",
+ "Trace channel specification for provisioned instances (e.g. default, cpu,log, memory)",
+ cxxopts::value<std::string>(m_ServerOptions.HubInstanceTrace)->default_value(""),
+ "<channels>");
+
+ Options.add_option("hub",
+ "",
+ "hub-instance-tracehost",
+ "Trace host for provisioned instances",
+ cxxopts::value<std::string>(m_ServerOptions.HubInstanceTraceHost)->default_value(""),
+ "<host>");
+
+ Options.add_option("hub",
+ "",
+ "hub-instance-tracefile",
+ "Trace file path for provisioned instances",
+ cxxopts::value<std::string>(m_ServerOptions.HubInstanceTraceFile)->default_value(""),
+ "<path>");
+
+ Options.add_option("hub",
+ "",
"hub-instance-http-threads",
"Number of http server connection threads for provisioned instances",
cxxopts::value<unsigned int>(m_ServerOptions.HubInstanceHttpThreadCount),
@@ -131,6 +188,16 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options)
cxxopts::value(m_ServerOptions.HubInstanceConfigPath),
"<instance config>");
+ const uint32_t DefaultHubInstanceProvisionThreadCount = Max(GetHardwareConcurrency() / 4u, 2u);
+
+ Options.add_option("hub",
+ "",
+ "hub-instance-provision-threads",
+ fmt::format("Number of threads for instance provisioning (default {})", DefaultHubInstanceProvisionThreadCount),
+ cxxopts::value<uint32_t>(m_ServerOptions.HubInstanceProvisionThreadCount)
+ ->default_value(fmt::format("{}", DefaultHubInstanceProvisionThreadCount)),
+ "<threads>");
+
Options.add_option("hub",
"",
"hub-hydration-target-spec",
@@ -139,6 +206,24 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options)
cxxopts::value(m_ServerOptions.HydrationTargetSpecification),
"<hydration-target-spec>");
+ Options.add_option("hub",
+ "",
+ "hub-hydration-target-config",
+ "Path to JSON file specifying the hydration target (mutually exclusive with "
+ "--hub-hydration-target-spec). Supported types: 'file', 's3'.",
+ cxxopts::value(m_ServerOptions.HydrationTargetConfigPath),
+ "<path>");
+
+ const uint32_t DefaultHubHydrationThreadCount = Max(GetHardwareConcurrency() / 4u, 2u);
+
+ Options.add_option(
+ "hub",
+ "",
+ "hub-hydration-threads",
+ fmt::format("Number of threads for hydration/dehydration (default {})", DefaultHubHydrationThreadCount),
+ cxxopts::value<uint32_t>(m_ServerOptions.HubHydrationThreadCount)->default_value(fmt::format("{}", DefaultHubHydrationThreadCount)),
+ "<threads>");
+
#if ZEN_PLATFORM_WINDOWS
Options.add_option("hub",
"",
@@ -203,12 +288,112 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options)
"Request timeout in milliseconds for instance activity check requests",
cxxopts::value<uint32_t>(m_ServerOptions.WatchdogConfig.ActivityCheckRequestTimeoutMs)->default_value("200"),
"<ms>");
+
+ Options.add_option("hub",
+ "",
+ "hub-provision-disk-limit-bytes",
+ "Reject provisioning when used disk bytes exceed this value (0 = no limit).",
+ cxxopts::value<uint64_t>(m_ServerOptions.HubProvisionDiskLimitBytes),
+ "<bytes>");
+
+ Options.add_option("hub",
+ "",
+ "hub-provision-disk-limit-percent",
+ "Reject provisioning when used disk exceeds this percentage of total disk (0 = no limit).",
+ cxxopts::value<uint32_t>(m_ServerOptions.HubProvisionDiskLimitPercent),
+ "<percent>");
+
+ Options.add_option("hub",
+ "",
+ "hub-provision-memory-limit-bytes",
+ "Reject provisioning when used memory bytes exceed this value (0 = no limit).",
+ cxxopts::value<uint64_t>(m_ServerOptions.HubProvisionMemoryLimitBytes),
+ "<bytes>");
+
+ Options.add_option("hub",
+ "",
+ "hub-provision-memory-limit-percent",
+ "Reject provisioning when used memory exceeds this percentage of total RAM (0 = no limit).",
+ cxxopts::value<uint32_t>(m_ServerOptions.HubProvisionMemoryLimitPercent),
+ "<percent>");
}
void
ZenHubServerConfigurator::AddConfigOptions(LuaConfig::Options& Options)
{
- ZEN_UNUSED(Options);
+ using namespace std::literals;
+
+ Options.AddOption("hub.upstreamnotification.endpoint"sv,
+ m_ServerOptions.UpstreamNotificationEndpoint,
+ "upstream-notification-endpoint"sv);
+ Options.AddOption("hub.upstreamnotification.instanceid"sv, m_ServerOptions.InstanceId, "upstream-notification-instance-id"sv);
+
+ Options.AddOption("hub.consul.endpoint"sv, m_ServerOptions.ConsulEndpoint, "consul-endpoint"sv);
+ Options.AddOption("hub.consul.tokenenv"sv, m_ServerOptions.ConsulTokenEnv, "consul-token-env"sv);
+ Options.AddOption("hub.consul.healthintervalseconds"sv,
+ m_ServerOptions.ConsulHealthIntervalSeconds,
+ "consul-health-interval-seconds"sv);
+ Options.AddOption("hub.consul.deregisterafterseconds"sv,
+ m_ServerOptions.ConsulDeregisterAfterSeconds,
+ "consul-deregister-after-seconds"sv);
+ Options.AddOption("hub.consul.registerhub"sv, m_ServerOptions.ConsulRegisterHub, "consul-register-hub"sv);
+
+ Options.AddOption("hub.instance.baseportnumber"sv, m_ServerOptions.HubBasePortNumber, "hub-instance-base-port-number"sv);
+ Options.AddOption("hub.instance.http"sv, m_ServerOptions.HubInstanceHttpClass, "hub-instance-http"sv);
+ Options.AddOption("hub.instance.malloc"sv, m_ServerOptions.HubInstanceMalloc, "hub-instance-malloc"sv);
+ Options.AddOption("hub.instance.trace"sv, m_ServerOptions.HubInstanceTrace, "hub-instance-trace"sv);
+ Options.AddOption("hub.instance.tracehost"sv, m_ServerOptions.HubInstanceTraceHost, "hub-instance-tracehost"sv);
+ Options.AddOption("hub.instance.tracefile"sv, m_ServerOptions.HubInstanceTraceFile, "hub-instance-tracefile"sv);
+ Options.AddOption("hub.instance.httpthreads"sv, m_ServerOptions.HubInstanceHttpThreadCount, "hub-instance-http-threads"sv);
+ Options.AddOption("hub.instance.corelimit"sv, m_ServerOptions.HubInstanceCoreLimit, "hub-instance-corelimit"sv);
+ Options.AddOption("hub.instance.config"sv, m_ServerOptions.HubInstanceConfigPath, "hub-instance-config"sv);
+ Options.AddOption("hub.instance.limits.count"sv, m_ServerOptions.HubInstanceLimit, "hub-instance-limit"sv);
+ Options.AddOption("hub.instance.limits.disklimitbytes"sv,
+ m_ServerOptions.HubProvisionDiskLimitBytes,
+ "hub-provision-disk-limit-bytes"sv);
+ Options.AddOption("hub.instance.limits.disklimitpercent"sv,
+ m_ServerOptions.HubProvisionDiskLimitPercent,
+ "hub-provision-disk-limit-percent"sv);
+ Options.AddOption("hub.instance.limits.memorylimitbytes"sv,
+ m_ServerOptions.HubProvisionMemoryLimitBytes,
+ "hub-provision-memory-limit-bytes"sv);
+ Options.AddOption("hub.instance.limits.memorylimitpercent"sv,
+ m_ServerOptions.HubProvisionMemoryLimitPercent,
+ "hub-provision-memory-limit-percent"sv);
+ Options.AddOption("hub.instance.provisionthreads"sv,
+ m_ServerOptions.HubInstanceProvisionThreadCount,
+ "hub-instance-provision-threads"sv);
+
+ Options.AddOption("hub.hydration.targetspec"sv, m_ServerOptions.HydrationTargetSpecification, "hub-hydration-target-spec"sv);
+ Options.AddOption("hub.hydration.targetconfig"sv, m_ServerOptions.HydrationTargetConfigPath, "hub-hydration-target-config"sv);
+ Options.AddOption("hub.hydration.threads"sv, m_ServerOptions.HubHydrationThreadCount, "hub-hydration-threads"sv);
+
+ Options.AddOption("hub.watchdog.cycleintervalms"sv, m_ServerOptions.WatchdogConfig.CycleIntervalMs, "hub-watchdog-cycle-interval-ms"sv);
+ Options.AddOption("hub.watchdog.cycleprocessingbudgetms"sv,
+ m_ServerOptions.WatchdogConfig.CycleProcessingBudgetMs,
+ "hub-watchdog-cycle-processing-budget-ms"sv);
+ Options.AddOption("hub.watchdog.instancecheckthrottlems"sv,
+ m_ServerOptions.WatchdogConfig.InstanceCheckThrottleMs,
+ "hub-watchdog-instance-check-throttle-ms"sv);
+ Options.AddOption("hub.watchdog.provisionedinactivitytimeoutseconds"sv,
+ m_ServerOptions.WatchdogConfig.ProvisionedInactivityTimeoutSeconds,
+ "hub-watchdog-provisioned-inactivity-timeout-seconds"sv);
+ Options.AddOption("hub.watchdog.hibernatedinactivitytimeoutseconds"sv,
+ m_ServerOptions.WatchdogConfig.HibernatedInactivityTimeoutSeconds,
+ "hub-watchdog-hibernated-inactivity-timeout-seconds"sv);
+ Options.AddOption("hub.watchdog.inactivitycheckmarginseconds"sv,
+ m_ServerOptions.WatchdogConfig.InactivityCheckMarginSeconds,
+ "hub-watchdog-inactivity-check-margin-seconds"sv);
+ Options.AddOption("hub.watchdog.activitycheckconnecttimeoutms"sv,
+ m_ServerOptions.WatchdogConfig.ActivityCheckConnectTimeoutMs,
+ "hub-watchdog-activity-check-connect-timeout-ms"sv);
+ Options.AddOption("hub.watchdog.activitycheckrequesttimeoutms"sv,
+ m_ServerOptions.WatchdogConfig.ActivityCheckRequestTimeoutMs,
+ "hub-watchdog-activity-check-request-timeout-ms"sv);
+
+#if ZEN_PLATFORM_WINDOWS
+ Options.AddOption("hub.usejobobject"sv, m_ServerOptions.HubUseJobObject, "hub-use-job-object"sv);
+#endif
}
void
@@ -226,6 +411,28 @@ ZenHubServerConfigurator::OnConfigFileParsed(LuaConfig::Options& LuaOptions)
void
ZenHubServerConfigurator::ValidateOptions()
{
+ if (m_ServerOptions.HubProvisionDiskLimitPercent > 100)
+ {
+ throw OptionParseException(
+ fmt::format("'--hub-provision-disk-limit-percent' ({}) must be in range 0..100", m_ServerOptions.HubProvisionDiskLimitPercent),
+ {});
+ }
+ if (m_ServerOptions.HubProvisionMemoryLimitPercent > 100)
+ {
+ throw OptionParseException(fmt::format("'--hub-provision-memory-limit-percent' ({}) must be in range 0..100",
+ m_ServerOptions.HubProvisionMemoryLimitPercent),
+ {});
+ }
+ if (!m_ServerOptions.HydrationTargetSpecification.empty() && !m_ServerOptions.HydrationTargetConfigPath.empty())
+ {
+ throw OptionParseException("'--hub-hydration-target-spec' and '--hub-hydration-target-config' are mutually exclusive", {});
+ }
+ if (!m_ServerOptions.HydrationTargetConfigPath.empty() && !std::filesystem::exists(m_ServerOptions.HydrationTargetConfigPath))
+ {
+ throw OptionParseException(
+ fmt::format("'--hub-hydration-target-config': file not found: '{}'", m_ServerOptions.HydrationTargetConfigPath.string()),
+ {});
+ }
}
///////////////////////////////////////////////////////////////////////////
@@ -247,6 +454,15 @@ ZenHubServer::OnModuleStateChanged(std::string_view HubInstanceId,
HubInstanceState NewState)
{
ZEN_UNUSED(PreviousState);
+
+ if (NewState == HubInstanceState::Deprovisioning || NewState == HubInstanceState::Hibernating)
+ {
+ if (Info.Port != 0)
+ {
+ m_Proxy->PrunePort(Info.Port);
+ }
+ }
+
if (!m_ConsulClient)
{
return;
@@ -262,12 +478,9 @@ ZenHubServer::OnModuleStateChanged(std::string_view HubInstanceId,
.Tags = std::vector<std::pair<std::string, std::string>>{std::make_pair("module", std::string(ModuleId)),
std::make_pair("zen-hub", std::string(HubInstanceId)),
std::make_pair("version", std::string(ZEN_CFG_VERSION))},
- .HealthIntervalSeconds = NewState == HubInstanceState::Provisioning
- ? 0u
- : m_ConsulHealthIntervalSeconds, // Disable health checks while not finished provisioning
- .DeregisterAfterSeconds = NewState == HubInstanceState::Provisioning
- ? 0u
- : m_ConsulDeregisterAfterSeconds}; // Disable health checks while not finished provisioning
+ .HealthIntervalSeconds = NewState == HubInstanceState::Provisioning ? 0u : m_ConsulHealthIntervalSeconds,
+ .DeregisterAfterSeconds = NewState == HubInstanceState::Provisioning ? 0u : m_ConsulDeregisterAfterSeconds,
+ .InitialStatus = NewState == HubInstanceState::Provisioned ? "passing" : ""};
if (!m_ConsulClient->RegisterService(ServiceInfo))
{
@@ -294,8 +507,8 @@ ZenHubServer::OnModuleStateChanged(std::string_view HubInstanceId,
ZEN_INFO("Deregistered storage server instance for module '{}' at port {} from Consul", ModuleId, Info.Port);
}
}
- // Transitional states (Deprovisioning, Hibernating, Waking, Recovering, Crashed)
- // and Hibernated are intentionally ignored.
+ // Transitional states (Waking, Recovering, Crashed) and stable states
+ // not handled above (Hibernated) are intentionally ignored by Consul.
}
int
@@ -317,6 +530,10 @@ ZenHubServer::Initialize(const ZenHubServerConfig& ServerConfig, ZenServerState:
// the main test range.
ZenServerEnvironment::SetBaseChildId(1000);
+ m_ProvisionWorkerPool =
+ std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubInstanceProvisionThreadCount), "hub_provision");
+ m_HydrationWorkerPool = std::make_unique<WorkerThreadPool>(gsl::narrow<int>(ServerConfig.HubHydrationThreadCount), "hub_hydration");
+
m_DebugOptionForcedCrash = ServerConfig.ShouldCrash;
InitializeState(ServerConfig);
@@ -342,12 +559,18 @@ ZenHubServer::Cleanup()
m_IoRunner.join();
}
- ShutdownServices();
if (m_Http)
{
m_Http->Close();
}
+ ShutdownServices();
+
+ if (m_Proxy)
+ {
+ m_Proxy->Shutdown();
+ }
+
if (m_Hub)
{
m_Hub->Shutdown();
@@ -357,6 +580,7 @@ ZenHubServer::Cleanup()
m_HubService.reset();
m_ApiService.reset();
m_Hub.reset();
+ m_Proxy.reset();
m_ConsulRegistration.reset();
m_ConsulClient.reset();
@@ -373,49 +597,121 @@ ZenHubServer::InitializeState(const ZenHubServerConfig& ServerConfig)
ZEN_UNUSED(ServerConfig);
}
+ResourceMetrics
+ZenHubServer::ResolveLimits(const ZenHubServerConfig& ServerConfig)
+{
+ uint64_t DiskTotal = 0;
+ uint64_t MemoryTotal = 0;
+
+ if (ServerConfig.HubProvisionDiskLimitPercent > 0)
+ {
+ DiskSpace Disk;
+ if (DiskSpaceInfo(ServerConfig.DataDir, Disk))
+ {
+ DiskTotal = Disk.Total;
+ }
+ else
+ {
+ ZEN_WARN("Failed to query disk space for '{}'; disk percent limit will not be applied", ServerConfig.DataDir);
+ }
+ }
+ if (ServerConfig.HubProvisionMemoryLimitPercent > 0)
+ {
+ MemoryTotal = GetSystemMetrics().SystemMemoryMiB * 1024 * 1024;
+ }
+
+ auto Resolve = [](uint64_t Bytes, uint32_t Pct, uint64_t Total) -> uint64_t {
+ const uint64_t PctBytes = Pct > 0 ? (Total * Pct) / 100 : 0;
+ if (Bytes > 0 && PctBytes > 0)
+ {
+ return Min(Bytes, PctBytes);
+ }
+ return Bytes > 0 ? Bytes : PctBytes;
+ };
+
+ return {
+ .DiskUsageBytes = Resolve(ServerConfig.HubProvisionDiskLimitBytes, ServerConfig.HubProvisionDiskLimitPercent, DiskTotal),
+ .MemoryUsageBytes = Resolve(ServerConfig.HubProvisionMemoryLimitBytes, ServerConfig.HubProvisionMemoryLimitPercent, MemoryTotal),
+ };
+}
+
void
ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig)
{
ZEN_INFO("instantiating Hub");
+ Hub::Configuration HubConfig{
+ .UseJobObject = ServerConfig.HubUseJobObject,
+ .BasePortNumber = ServerConfig.HubBasePortNumber,
+ .InstanceLimit = ServerConfig.HubInstanceLimit,
+ .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount,
+ .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit,
+ .InstanceMalloc = ServerConfig.HubInstanceMalloc,
+ .InstanceTrace = ServerConfig.HubInstanceTrace,
+ .InstanceTraceHost = ServerConfig.HubInstanceTraceHost,
+ .InstanceTraceFile = ServerConfig.HubInstanceTraceFile,
+ .InstanceConfigPath = ServerConfig.HubInstanceConfigPath,
+ .HydrationTargetSpecification = ServerConfig.HydrationTargetSpecification,
+ .WatchDog =
+ {
+ .CycleInterval = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleIntervalMs),
+ .CycleProcessingBudget = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleProcessingBudgetMs),
+ .InstanceCheckThrottle = std::chrono::milliseconds(ServerConfig.WatchdogConfig.InstanceCheckThrottleMs),
+ .ProvisionedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.ProvisionedInactivityTimeoutSeconds),
+ .HibernatedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.HibernatedInactivityTimeoutSeconds),
+ .InactivityCheckMargin = std::chrono::seconds(ServerConfig.WatchdogConfig.InactivityCheckMarginSeconds),
+ .ActivityCheckConnectTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckConnectTimeoutMs),
+ .ActivityCheckRequestTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckRequestTimeoutMs),
+ },
+ .ResourceLimits = ResolveLimits(ServerConfig),
+ .OptionalProvisionWorkerPool = m_ProvisionWorkerPool.get(),
+ .OptionalHydrationWorkerPool = m_HydrationWorkerPool.get()};
+
+ if (!ServerConfig.HydrationTargetConfigPath.empty())
+ {
+ FileContents Contents = ReadFile(ServerConfig.HydrationTargetConfigPath);
+ if (!Contents)
+ {
+ throw zen::runtime_error("Failed to read hydration config '{}': {}",
+ ServerConfig.HydrationTargetConfigPath.string(),
+ Contents.ErrorCode.message());
+ }
+ IoBuffer Buffer(Contents.Flatten());
+ std::string_view JsonText(static_cast<const char*>(Buffer.GetData()), Buffer.GetSize());
+
+ std::string ParseError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(JsonText, ParseError);
+ if (!ParseError.empty() || !Root.IsObject())
+ {
+ throw zen::runtime_error("Failed to parse hydration config '{}': {}",
+ ServerConfig.HydrationTargetConfigPath.string(),
+ ParseError.empty() ? "root must be a JSON object" : ParseError);
+ }
+ HubConfig.HydrationOptions = std::move(Root).AsObject();
+ }
+
+ m_Proxy = std::make_unique<HttpProxyHandler>();
+
m_Hub = std::make_unique<Hub>(
- Hub::Configuration{
- .UseJobObject = ServerConfig.HubUseJobObject,
- .BasePortNumber = ServerConfig.HubBasePortNumber,
- .InstanceLimit = ServerConfig.HubInstanceLimit,
- .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount,
- .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit,
- .InstanceConfigPath = ServerConfig.HubInstanceConfigPath,
- .HydrationTargetSpecification = ServerConfig.HydrationTargetSpecification,
- .WatchDog =
- {
- .CycleInterval = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleIntervalMs),
- .CycleProcessingBudget = std::chrono::milliseconds(ServerConfig.WatchdogConfig.CycleProcessingBudgetMs),
- .InstanceCheckThrottle = std::chrono::milliseconds(ServerConfig.WatchdogConfig.InstanceCheckThrottleMs),
- .ProvisionedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.ProvisionedInactivityTimeoutSeconds),
- .HibernatedInactivityTimeout = std::chrono::seconds(ServerConfig.WatchdogConfig.HibernatedInactivityTimeoutSeconds),
- .InactivityCheckMargin = std::chrono::seconds(ServerConfig.WatchdogConfig.InactivityCheckMarginSeconds),
- .ActivityCheckConnectTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckConnectTimeoutMs),
- .ActivityCheckRequestTimeout = std::chrono::milliseconds(ServerConfig.WatchdogConfig.ActivityCheckRequestTimeoutMs),
- }},
+ std::move(HubConfig),
ZenServerEnvironment(ZenServerEnvironment::Hub,
ServerConfig.DataDir / "hub",
ServerConfig.DataDir / "servers",
ServerConfig.HubInstanceHttpClass),
- &GetMediumWorkerPool(EWorkloadType::Background),
- m_ConsulClient ? Hub::AsyncModuleStateChangeCallbackFunc{[this, HubInstanceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId)](
- std::string_view ModuleId,
- const HubProvisionedInstanceInfo& Info,
- HubInstanceState PreviousState,
- HubInstanceState NewState) {
- OnModuleStateChanged(HubInstanceId, ModuleId, Info, PreviousState, NewState);
- }}
- : Hub::AsyncModuleStateChangeCallbackFunc{});
+ Hub::AsyncModuleStateChangeCallbackFunc{
+ [this, HubInstanceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId)](std::string_view ModuleId,
+ const HubProvisionedInstanceInfo& Info,
+ HubInstanceState PreviousState,
+ HubInstanceState NewState) {
+ OnModuleStateChanged(HubInstanceId, ModuleId, Info, PreviousState, NewState);
+ }});
+
+ m_Proxy->SetPortValidator([Hub = m_Hub.get()](uint16_t Port) { return Hub->IsInstancePort(Port); });
ZEN_INFO("instantiating API service");
m_ApiService = std::make_unique<zen::HttpApiService>(*m_Http);
ZEN_INFO("instantiating hub service");
- m_HubService = std::make_unique<HttpHubService>(*m_Hub, m_StatsService, m_StatusService);
+ m_HubService = std::make_unique<HttpHubService>(*m_Hub, *m_Proxy, m_StatsService, m_StatusService);
m_HubService->SetNotificationEndpoint(ServerConfig.UpstreamNotificationEndpoint, ServerConfig.InstanceId);
m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatsService, m_StatusService);
@@ -465,21 +761,32 @@ ZenHubServer::InitializeConsulRegistration(const ZenHubServerConfig& ServerConfi
}
else
{
- ZEN_INFO("Consul token read from environment variable '{}'", ConsulAccessTokenEnvName);
+ ZEN_INFO("Consul token will be read from environment variable '{}'", ConsulAccessTokenEnvName);
}
try
{
- m_ConsulClient = std::make_unique<consul::ConsulClient>(ServerConfig.ConsulEndpoint, ConsulAccessToken);
+ m_ConsulClient = std::make_unique<consul::ConsulClient>(consul::ConsulClient::Configuration{
+ .BaseUri = ServerConfig.ConsulEndpoint,
+ .TokenEnvName = ConsulAccessTokenEnvName,
+ });
m_ConsulHealthIntervalSeconds = ServerConfig.ConsulHealthIntervalSeconds;
m_ConsulDeregisterAfterSeconds = ServerConfig.ConsulDeregisterAfterSeconds;
+ if (!ServerConfig.ConsulRegisterHub)
+ {
+ ZEN_INFO(
+ "Hub parent Consul registration skipped (consul-register-hub is false); "
+ "instance registration remains enabled");
+ return;
+ }
+
consul::ServiceRegistrationInfo Info;
Info.ServiceId = fmt::format("zen-hub-{}", ServerConfig.InstanceId);
Info.ServiceName = "zen-hub";
// Info.Address = "localhost"; // Let the consul agent figure out out external address // TODO: Info.BaseUri?
Info.Port = static_cast<uint16_t>(EffectivePort);
- Info.HealthEndpoint = "hub/health";
+ Info.HealthEndpoint = "health";
Info.Tags = std::vector<std::pair<std::string, std::string>>{
std::make_pair("zen-hub", Info.ServiceId),
std::make_pair("version", std::string(ZEN_CFG_VERSION)),
@@ -569,6 +876,8 @@ ZenHubServer::Run()
OnReady();
+ StartSelfSession("zenhub");
+
m_Http->Run(IsInteractiveMode);
SetNewState(kShuttingDown);
diff --git a/src/zenserver/hub/zenhubserver.h b/src/zenserver/hub/zenhubserver.h
index 77df3eaa3..5e465bb14 100644
--- a/src/zenserver/hub/zenhubserver.h
+++ b/src/zenserver/hub/zenhubserver.h
@@ -3,8 +3,10 @@
#pragma once
#include "hubinstancestate.h"
+#include "resourcemetrics.h"
#include "zenserver.h"
+#include <zencore/workthreadpool.h>
#include <zenutil/consul.h>
namespace cxxopts {
@@ -19,6 +21,7 @@ namespace zen {
class HttpApiService;
class HttpFrontendService;
class HttpHubService;
+class HttpProxyHandler;
struct ZenHubWatchdogConfig
{
@@ -34,21 +37,33 @@ struct ZenHubWatchdogConfig
struct ZenHubServerConfig : public ZenServerConfig
{
- std::string UpstreamNotificationEndpoint;
- std::string InstanceId; // For use in notifications
- std::string ConsulEndpoint; // If set, enables Consul service registration
- std::string ConsulTokenEnv; // Environment variable name to read a Consul token from; defaults to CONSUL_HTTP_TOKEN if empty
- uint32_t ConsulHealthIntervalSeconds = 10; // Interval in seconds between Consul health checks
- uint32_t ConsulDeregisterAfterSeconds = 30; // Seconds before Consul deregisters an unhealthy service
- uint16_t HubBasePortNumber = 21000;
- int HubInstanceLimit = 1000;
- bool HubUseJobObject = true;
- std::string HubInstanceHttpClass = "asio";
- uint32_t HubInstanceHttpThreadCount = 0; // Automatic
- int HubInstanceCoreLimit = 0; // Automatic
- std::filesystem::path HubInstanceConfigPath; // Path to Lua config file
- std::string HydrationTargetSpecification; // hydration/dehydration target specification
+ std::string UpstreamNotificationEndpoint;
+ std::string InstanceId; // For use in notifications
+ std::string ConsulEndpoint; // If set, enables Consul service registration
+ std::string ConsulTokenEnv; // Environment variable name to read a Consul token from; defaults to CONSUL_HTTP_TOKEN if empty
+ uint32_t ConsulHealthIntervalSeconds = 10; // Interval in seconds between Consul health checks
+ uint32_t ConsulDeregisterAfterSeconds = 30; // Seconds before Consul deregisters an unhealthy service
+ bool ConsulRegisterHub = true; // Whether to register the hub parent service with Consul (instance registration unaffected)
+ uint16_t HubBasePortNumber = 21000;
+ int HubInstanceLimit = 1000;
+ bool HubUseJobObject = true;
+ std::string HubInstanceHttpClass = "asio";
+ std::string HubInstanceMalloc;
+ std::string HubInstanceTrace;
+ std::string HubInstanceTraceHost;
+ std::string HubInstanceTraceFile;
+ uint32_t HubInstanceHttpThreadCount = 0; // Automatic
+ uint32_t HubInstanceProvisionThreadCount = 0; // Synchronous provisioning
+ uint32_t HubHydrationThreadCount = 0; // Synchronous hydration/dehydration
+ int HubInstanceCoreLimit = 0; // Automatic
+ std::filesystem::path HubInstanceConfigPath; // Path to Lua config file
+ std::string HydrationTargetSpecification; // hydration/dehydration target specification
+ std::filesystem::path HydrationTargetConfigPath; // path to JSON config file (mutually exclusive with HydrationTargetSpecification)
ZenHubWatchdogConfig WatchdogConfig;
+ uint64_t HubProvisionDiskLimitBytes = 0;
+ uint32_t HubProvisionDiskLimitPercent = 0;
+ uint64_t HubProvisionMemoryLimitBytes = 0;
+ uint32_t HubProvisionMemoryLimitPercent = 0;
};
class Hub;
@@ -115,7 +130,10 @@ private:
std::filesystem::path m_ContentRoot;
bool m_DebugOptionForcedCrash = false;
- std::unique_ptr<Hub> m_Hub;
+ std::unique_ptr<HttpProxyHandler> m_Proxy;
+ std::unique_ptr<WorkerThreadPool> m_ProvisionWorkerPool;
+ std::unique_ptr<WorkerThreadPool> m_HydrationWorkerPool;
+ std::unique_ptr<Hub> m_Hub;
std::unique_ptr<HttpHubService> m_HubService;
std::unique_ptr<HttpApiService> m_ApiService;
@@ -126,6 +144,8 @@ private:
uint32_t m_ConsulHealthIntervalSeconds = 10;
uint32_t m_ConsulDeregisterAfterSeconds = 30;
+ static ResourceMetrics ResolveLimits(const ZenHubServerConfig& ServerConfig);
+
void InitializeState(const ZenHubServerConfig& ServerConfig);
void InitializeServices(const ZenHubServerConfig& ServerConfig);
void RegisterServices(const ZenHubServerConfig& ServerConfig);
diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp
index 00b7a67d7..108685eb9 100644
--- a/src/zenserver/main.cpp
+++ b/src/zenserver/main.cpp
@@ -14,7 +14,6 @@
#include <zencore/memory/memorytrace.h>
#include <zencore/memory/newdelete.h>
#include <zencore/scopeguard.h>
-#include <zencore/sentryintegration.h>
#include <zencore/session.h>
#include <zencore/string.h>
#include <zencore/thread.h>
@@ -169,7 +168,12 @@ AppMain(int argc, char* argv[])
if (IsDir(ServerOptions.DataDir))
{
ZEN_CONSOLE_INFO("Deleting files from '{}' ({})", ServerOptions.DataDir, DeleteReason);
- DeleteDirectories(ServerOptions.DataDir);
+ std::error_code Ec;
+ DeleteDirectories(ServerOptions.DataDir, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("could not fully clean '{}': {} (continuing anyway)", ServerOptions.DataDir, Ec.message());
+ }
}
}
@@ -250,7 +254,7 @@ test_main(int argc, char** argv)
zen::MaximizeOpenFileCount();
zen::testing::TestRunner Runner;
- Runner.ApplyCommandLine(argc, argv);
+ Runner.ApplyCommandLine(argc, argv, "server.*");
return Runner.Run();
}
#endif
diff --git a/src/zenserver/proxy/httptrafficinspector.cpp b/src/zenserver/proxy/httptrafficinspector.cpp
index 74ecbfd48..913bd2c28 100644
--- a/src/zenserver/proxy/httptrafficinspector.cpp
+++ b/src/zenserver/proxy/httptrafficinspector.cpp
@@ -10,29 +10,33 @@
namespace zen {
// clang-format off
-http_parser_settings HttpTrafficInspector::s_RequestSettings{
- .on_message_begin = [](http_parser*) { return 0; },
- .on_url = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnUrl(Data, Len); },
- .on_status = [](http_parser*, const char*, size_t) { return 0; },
- .on_header_field = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderField(Data, Len); },
- .on_header_value = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderValue(Data, Len); },
- .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); },
- .on_body = [](http_parser*, const char*, size_t) { return 0; },
- .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); },
- .on_chunk_header{},
- .on_chunk_complete{}};
-
-http_parser_settings HttpTrafficInspector::s_ResponseSettings{
- .on_message_begin = [](http_parser*) { return 0; },
- .on_url = [](http_parser*, const char*, size_t) { return 0; },
- .on_status = [](http_parser*, const char*, size_t) { return 0; },
- .on_header_field = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderField(Data, Len); },
- .on_header_value = [](http_parser* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderValue(Data, Len); },
- .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); },
- .on_body = [](http_parser*, const char*, size_t) { return 0; },
- .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); },
- .on_chunk_header{},
- .on_chunk_complete{}};
+llhttp_settings_t HttpTrafficInspector::s_RequestSettings = []() {
+ llhttp_settings_t S;
+ llhttp_settings_init(&S);
+ S.on_message_begin = [](llhttp_t*) { return 0; };
+ S.on_url = [](llhttp_t* p, const char* Data, size_t Len) { return GetThis(p)->OnUrl(Data, Len); };
+ S.on_status = [](llhttp_t*, const char*, size_t) { return 0; };
+ S.on_header_field = [](llhttp_t* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderField(Data, Len); };
+ S.on_header_value = [](llhttp_t* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderValue(Data, Len); };
+ S.on_headers_complete = [](llhttp_t* p) { return GetThis(p)->OnHeadersComplete(); };
+ S.on_body = [](llhttp_t*, const char*, size_t) { return 0; };
+ S.on_message_complete = [](llhttp_t* p) { return GetThis(p)->OnMessageComplete(); };
+ return S;
+}();
+
+llhttp_settings_t HttpTrafficInspector::s_ResponseSettings = []() {
+ llhttp_settings_t S;
+ llhttp_settings_init(&S);
+ S.on_message_begin = [](llhttp_t*) { return 0; };
+ S.on_url = [](llhttp_t*, const char*, size_t) { return 0; };
+ S.on_status = [](llhttp_t*, const char*, size_t) { return 0; };
+ S.on_header_field = [](llhttp_t* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderField(Data, Len); };
+ S.on_header_value = [](llhttp_t* p, const char* Data, size_t Len) { return GetThis(p)->OnHeaderValue(Data, Len); };
+ S.on_headers_complete = [](llhttp_t* p) { return GetThis(p)->OnHeadersComplete(); };
+ S.on_body = [](llhttp_t*, const char*, size_t) { return 0; };
+ S.on_message_complete = [](llhttp_t* p) { return GetThis(p)->OnMessageComplete(); };
+ return S;
+}();
// clang-format on
HttpTrafficInspector::HttpTrafficInspector(Direction Dir, std::string_view SessionLabel)
@@ -40,7 +44,8 @@ HttpTrafficInspector::HttpTrafficInspector(Direction Dir, std::string_view Sessi
, m_Direction(Dir)
, m_SessionLabel(SessionLabel)
{
- http_parser_init(&m_Parser, Dir == Direction::Request ? HTTP_REQUEST : HTTP_RESPONSE);
+ llhttp_settings_t* Settings = (Dir == Direction::Request) ? &s_RequestSettings : &s_ResponseSettings;
+ llhttp_init(&m_Parser, Dir == Direction::Request ? HTTP_REQUEST : HTTP_RESPONSE, Settings);
m_Parser.data = this;
}
@@ -52,11 +57,9 @@ HttpTrafficInspector::Inspect(const char* Data, size_t Length)
return;
}
- http_parser_settings* Settings = (m_Direction == Direction::Request) ? &s_RequestSettings : &s_ResponseSettings;
+ llhttp_errno_t Err = llhttp_execute(&m_Parser, Data, Length);
- size_t Parsed = http_parser_execute(&m_Parser, Settings, Data, Length);
-
- if (m_Parser.upgrade)
+ if (Err == HPE_PAUSED_UPGRADE)
{
if (m_Direction == Direction::Request)
{
@@ -72,15 +75,9 @@ HttpTrafficInspector::Inspect(const char* Data, size_t Length)
return;
}
- http_errno Error = HTTP_PARSER_ERRNO(&m_Parser);
- if (Error != HPE_OK)
- {
- ZEN_DEBUG("[{}] non-HTTP traffic detected ({}), disabling inspection", m_SessionLabel, http_errno_name(Error));
- m_Disabled = true;
- }
- else if (Parsed != Length)
+ if (Err != HPE_OK)
{
- ZEN_DEBUG("[{}] parser consumed {}/{} bytes, disabling inspection", m_SessionLabel, Parsed, Length);
+ ZEN_DEBUG("[{}] non-HTTP traffic detected ({}), disabling inspection", m_SessionLabel, llhttp_errno_name(Err));
m_Disabled = true;
}
}
@@ -127,11 +124,11 @@ HttpTrafficInspector::OnHeadersComplete()
{
if (m_Direction == Direction::Request)
{
- m_Method = http_method_str(static_cast<http_method>(m_Parser.method));
+ m_Method = llhttp_method_name(static_cast<llhttp_method_t>(llhttp_get_method(&m_Parser)));
}
else
{
- m_StatusCode = m_Parser.status_code;
+ m_StatusCode = static_cast<uint16_t>(llhttp_get_status_code(&m_Parser));
}
return 0;
}
diff --git a/src/zenserver/proxy/httptrafficinspector.h b/src/zenserver/proxy/httptrafficinspector.h
index f4af0e77e..8192632ba 100644
--- a/src/zenserver/proxy/httptrafficinspector.h
+++ b/src/zenserver/proxy/httptrafficinspector.h
@@ -6,7 +6,7 @@
#include <zencore/uid.h>
ZEN_THIRD_PARTY_INCLUDES_START
-#include <http_parser.h>
+#include <llhttp.h>
ZEN_THIRD_PARTY_INCLUDES_END
#include <atomic>
@@ -45,15 +45,15 @@ private:
void ResetMessageState();
- static HttpTrafficInspector* GetThis(http_parser* Parser) { return static_cast<HttpTrafficInspector*>(Parser->data); }
+ static HttpTrafficInspector* GetThis(llhttp_t* Parser) { return static_cast<HttpTrafficInspector*>(Parser->data); }
- static http_parser_settings s_RequestSettings;
- static http_parser_settings s_ResponseSettings;
+ static llhttp_settings_t s_RequestSettings;
+ static llhttp_settings_t s_ResponseSettings;
LoggerRef Log() { return m_Log; }
LoggerRef m_Log;
- http_parser m_Parser;
+ llhttp_t m_Parser;
Direction m_Direction;
std::string m_SessionLabel;
bool m_Disabled = false;
diff --git a/src/zenserver/proxy/zenproxyserver.cpp b/src/zenserver/proxy/zenproxyserver.cpp
index 7e59a7b7e..ffa9a4295 100644
--- a/src/zenserver/proxy/zenproxyserver.cpp
+++ b/src/zenserver/proxy/zenproxyserver.cpp
@@ -257,7 +257,7 @@ ZenProxyServerConfigurator::ValidateOptions()
for (const std::string& Raw : m_RawProxyMappings)
{
// The mode keyword "proxy" from argv[1] gets captured as a positional
- // argument — skip it.
+ // argument - skip it.
if (Raw == "proxy")
{
continue;
@@ -304,7 +304,7 @@ ZenProxyServer::Initialize(const ZenProxyServerConfig& ServerConfig, ZenServerSt
// worker threads don't exit prematurely between async operations.
m_ProxyIoWorkGuard.emplace(m_ProxyIoContext.get_executor());
- // Start proxy I/O worker threads. Use a modest thread count — proxy work is
+ // Start proxy I/O worker threads. Use a modest thread count - proxy work is
// I/O-bound so we don't need a thread per core, but having more than one
// avoids head-of-line blocking when many connections are active.
unsigned int ThreadCount = std::max(GetHardwareConcurrency() / 4, 4u);
@@ -385,6 +385,8 @@ ZenProxyServer::Run()
OnReady();
+ StartSelfSession("zenproxy");
+
m_Http->Run(IsInteractiveMode);
SetNewState(kShuttingDown);
@@ -422,15 +424,16 @@ ZenProxyServer::Cleanup()
m_IoRunner.join();
}
- m_ProxyStatsService.reset();
- m_FrontendService.reset();
- m_ApiService.reset();
-
- ShutdownServices();
if (m_Http)
{
m_Http->Close();
}
+
+ ShutdownServices();
+
+ m_ProxyStatsService.reset();
+ m_FrontendService.reset();
+ m_ApiService.reset();
}
catch (const std::exception& Ex)
{
diff --git a/src/zenserver/sessions/httpsessions.cpp b/src/zenserver/sessions/httpsessions.cpp
index fdf2e1f21..56a22fb04 100644
--- a/src/zenserver/sessions/httpsessions.cpp
+++ b/src/zenserver/sessions/httpsessions.cpp
@@ -377,7 +377,7 @@ HttpSessionsService::SessionLogRequest(HttpRouterRequest& Req)
if (ServerRequest.RequestContentType() == HttpContentType::kText)
{
- // Raw text — split by newlines, one entry per line
+ // Raw text - split by newlines, one entry per line
IoBuffer Payload = ServerRequest.ReadPayload();
std::string_view Text(reinterpret_cast<const char*>(Payload.GetData()), Payload.GetSize());
const DateTime Now = DateTime::Now();
@@ -512,8 +512,9 @@ HttpSessionsService::SessionLogRequest(HttpRouterRequest& Req)
//
void
-HttpSessionsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpSessionsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
+ ZEN_UNUSED(RelativeUri);
ZEN_INFO("Sessions WebSocket client connected");
m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
}
diff --git a/src/zenserver/sessions/httpsessions.h b/src/zenserver/sessions/httpsessions.h
index 86a23f835..6ebe61c8d 100644
--- a/src/zenserver/sessions/httpsessions.h
+++ b/src/zenserver/sessions/httpsessions.h
@@ -37,7 +37,7 @@ public:
void SetSelfSessionId(const Oid& Id) { m_SelfSessionId = Id; }
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
diff --git a/src/zenserver/sessions/sessions.cpp b/src/zenserver/sessions/sessions.cpp
index 1212ba5d8..9d4e3120c 100644
--- a/src/zenserver/sessions/sessions.cpp
+++ b/src/zenserver/sessions/sessions.cpp
@@ -129,7 +129,7 @@ SessionsService::~SessionsService() = default;
bool
SessionsService::RegisterSession(const Oid& SessionId, std::string AppName, std::string Mode, const Oid& JobId, CbObjectView Metadata)
{
- // Log outside the lock scope — InProcSessionLogSink calls back into
+ // Log outside the lock scope - InProcSessionLogSink calls back into
// GetSession() which acquires m_Lock shared, so logging while holding
// m_Lock exclusively would deadlock.
{
diff --git a/src/zenserver/storage/buildstore/httpbuildstore.cpp b/src/zenserver/storage/buildstore/httpbuildstore.cpp
index bbbb0c37b..f935e2c6b 100644
--- a/src/zenserver/storage/buildstore/httpbuildstore.cpp
+++ b/src/zenserver/storage/buildstore/httpbuildstore.cpp
@@ -162,96 +162,81 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req)
fmt::format("Invalid blob hash '{}'", Hash));
}
- std::vector<std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs;
+ m_BuildStoreStats.BlobReadCount++;
+ IoBuffer Blob = m_BuildStore.GetBlob(BlobHash);
+ if (!Blob)
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, fmt::format("Blob {} not found", Hash));
+ }
+ m_BuildStoreStats.BlobHitCount++;
+
if (ServerRequest.RequestVerb() == HttpVerb::kPost)
{
+ if (ServerRequest.AcceptContentType() != HttpContentType::kCbPackage)
+ {
+ m_BuildStoreStats.BadRequestCount++;
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ fmt::format("Accept type '{}' is not supported for blob {}, expected '{}'",
+ ToString(ServerRequest.AcceptContentType()),
+ Hash,
+ ToString(HttpContentType::kCbPackage)));
+ }
+
CbObject RangePayload = ServerRequest.ReadPayloadObject();
- if (RangePayload)
+ if (!RangePayload)
{
- CbArrayView RangesArray = RangePayload["ranges"sv].AsArrayView();
- OffsetAndLengthPairs.reserve(RangesArray.Num());
- for (CbFieldView FieldView : RangesArray)
- {
- CbObjectView RangeView = FieldView.AsObjectView();
- uint64_t RangeOffset = RangeView["offset"sv].AsUInt64();
- uint64_t RangeLength = RangeView["length"sv].AsUInt64();
- OffsetAndLengthPairs.push_back(std::make_pair(RangeOffset, RangeLength));
- }
- if (OffsetAndLengthPairs.size() > MaxRangeCountPerRequestSupported)
- {
- return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
- HttpContentType::kText,
- fmt::format("Number of ranges ({}) for blob request exceeds maximum range count {}",
- OffsetAndLengthPairs.size(),
- MaxRangeCountPerRequestSupported));
- }
+ m_BuildStoreStats.BadRequestCount++;
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ fmt::format("Missing payload for range request on blob {}", BlobHash));
}
- if (OffsetAndLengthPairs.empty())
+
+ CbArrayView RangesArray = RangePayload["ranges"sv].AsArrayView();
+ const uint64_t RangeCount = RangesArray.Num();
+ if (RangeCount == 0)
{
m_BuildStoreStats.BadRequestCount++;
return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
HttpContentType::kText,
- "Fetching blob without ranges must be done with the GET verb");
+ "POST request must include a non-empty 'ranges' array");
}
- }
- else
- {
- HttpRanges Ranges;
- bool HasRange = ServerRequest.TryGetRanges(Ranges);
- if (HasRange)
+ if (RangeCount > MaxRangeCountPerRequestSupported)
{
- if (Ranges.size() > 1)
- {
- // Only a single http range is supported, we have limited support for http multirange responses
- m_BuildStoreStats.BadRequestCount++;
- return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
- HttpContentType::kText,
- fmt::format("Multiple ranges in blob request is only supported for {} accept type",
- ToString(HttpContentType::kCbPackage)));
- }
- const HttpRange& FirstRange = Ranges.front();
- OffsetAndLengthPairs.push_back(std::make_pair<uint64_t, uint64_t>(FirstRange.Start, FirstRange.End - FirstRange.Start + 1));
+ m_BuildStoreStats.BadRequestCount++;
+ return ServerRequest.WriteResponse(
+ HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ fmt::format("Range count {} exceeds maximum of {}", RangeCount, MaxRangeCountPerRequestSupported));
}
- }
-
- m_BuildStoreStats.BlobReadCount++;
- IoBuffer Blob = m_BuildStore.GetBlob(BlobHash);
- if (!Blob)
- {
- return ServerRequest.WriteResponse(HttpResponseCode::NotFound,
- HttpContentType::kText,
- fmt::format("Blob with hash '{}' could not be found", Hash));
- }
- m_BuildStoreStats.BlobHitCount++;
- if (OffsetAndLengthPairs.empty())
- {
- return ServerRequest.WriteResponse(HttpResponseCode::OK, Blob.GetContentType(), Blob);
- }
+ const uint64_t BlobSize = Blob.GetSize();
+ std::vector<IoBuffer> RangeBuffers;
+ RangeBuffers.reserve(RangeCount);
- if (ServerRequest.AcceptContentType() == HttpContentType::kCbPackage)
- {
- const uint64_t BlobSize = Blob.GetSize();
+ CbPackage ResponsePackage;
+ CbObjectWriter Writer;
- CbPackage ResponsePackage;
- std::vector<IoBuffer> RangeBuffers;
- CbObjectWriter Writer;
Writer.BeginArray("ranges"sv);
- for (const std::pair<uint64_t, uint64_t>& Range : OffsetAndLengthPairs)
+ for (CbFieldView FieldView : RangesArray)
{
- const uint64_t MaxBlobSize = Range.first < BlobSize ? BlobSize - Range.first : 0;
- const uint64_t RangeSize = Min(Range.second, MaxBlobSize);
+ CbObjectView RangeView = FieldView.AsObjectView();
+ uint64_t RangeOffset = RangeView["offset"sv].AsUInt64();
+ uint64_t RangeLength = RangeView["length"sv].AsUInt64();
+
+ const uint64_t MaxBlobSize = RangeOffset < BlobSize ? BlobSize - RangeOffset : 0;
+ const uint64_t RangeSize = Min(RangeLength, MaxBlobSize);
Writer.BeginObject();
{
- if (Range.first + RangeSize <= BlobSize)
+ if (RangeOffset + RangeSize <= BlobSize)
{
- RangeBuffers.push_back(IoBuffer(Blob, Range.first, RangeSize));
- Writer.AddInteger("offset"sv, Range.first);
+ RangeBuffers.push_back(IoBuffer(Blob, RangeOffset, RangeSize));
+ Writer.AddInteger("offset"sv, RangeOffset);
Writer.AddInteger("length"sv, RangeSize);
}
else
{
- Writer.AddInteger("offset"sv, Range.first);
+ Writer.AddInteger("offset"sv, RangeOffset);
Writer.AddInteger("length"sv, 0);
}
}
@@ -259,7 +244,7 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req)
}
Writer.EndArray();
- CompositeBuffer Ranges(RangeBuffers);
+ CompositeBuffer Ranges(std::move(RangeBuffers));
CbAttachment PayloadAttachment(std::move(Ranges), BlobHash);
Writer.AddAttachment("payload", PayloadAttachment);
@@ -269,32 +254,21 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req)
ResponsePackage.SetObject(HeaderObject);
CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(ResponsePackage);
- uint64_t ResponseSize = RpcResponseBuffer.GetSize();
- ZEN_UNUSED(ResponseSize);
return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, RpcResponseBuffer);
}
else
{
- if (OffsetAndLengthPairs.size() != 1)
+ HttpRanges RequestedRangeHeader;
+ bool HasRange = ServerRequest.TryGetRanges(RequestedRangeHeader);
+ if (HasRange)
{
- // Only a single http range is supported, we have limited support for http multirange responses
- m_BuildStoreStats.BadRequestCount++;
- return ServerRequest.WriteResponse(
- HttpResponseCode::BadRequest,
- HttpContentType::kText,
- fmt::format("Multiple ranges in blob request is only supported for {} accept type", ToString(HttpContentType::kCbPackage)));
+ // Standard HTTP GET with Range header: framework handles 206, Content-Range, and 416 on OOB.
+ return ServerRequest.WriteResponse(HttpContentType::kBinary, Blob, RequestedRangeHeader);
}
-
- const std::pair<uint64_t, uint64_t>& OffsetAndLength = OffsetAndLengthPairs.front();
- const uint64_t BlobSize = Blob.GetSize();
- const uint64_t MaxBlobSize = OffsetAndLength.first < BlobSize ? BlobSize - OffsetAndLength.first : 0;
- const uint64_t RangeSize = Min(OffsetAndLength.second, MaxBlobSize);
- if (OffsetAndLength.first + RangeSize > BlobSize)
+ else
{
- return ServerRequest.WriteResponse(HttpResponseCode::NoContent);
+ return ServerRequest.WriteResponse(HttpResponseCode::OK, Blob.GetContentType(), Blob);
}
- Blob = IoBuffer(Blob, OffsetAndLength.first, RangeSize);
- return ServerRequest.WriteResponse(HttpResponseCode::OK, ZenContentType::kBinary, Blob);
}
}
diff --git a/src/zenserver/storage/cache/httpstructuredcache.cpp b/src/zenserver/storage/cache/httpstructuredcache.cpp
index c1727270c..8ad48225b 100644
--- a/src/zenserver/storage/cache/httpstructuredcache.cpp
+++ b/src/zenserver/storage/cache/httpstructuredcache.cpp
@@ -80,7 +80,8 @@ HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCach
HttpStatusService& StatusService,
UpstreamCache& UpstreamCache,
const DiskWriteBlocker* InDiskWriteBlocker,
- OpenProcessCache& InOpenProcessCache)
+ OpenProcessCache& InOpenProcessCache,
+ const ILocalRefPolicy* InLocalRefPolicy)
: m_Log(logging::Get("cache"))
, m_CacheStore(InCacheStore)
, m_StatsService(StatsService)
@@ -90,6 +91,7 @@ HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCach
, m_DiskWriteBlocker(InDiskWriteBlocker)
, m_OpenProcessCache(InOpenProcessCache)
, m_RpcHandler(m_Log, m_CacheStats, UpstreamCache, InCacheStore, InCidStore, InDiskWriteBlocker)
+, m_LocalRefPolicy(InLocalRefPolicy)
{
m_StatsService.RegisterHandler("z$", *this);
m_StatusService.RegisterHandler("z$", *this);
@@ -114,6 +116,18 @@ HttpStructuredCacheService::BaseUri() const
return "/z$/";
}
+bool
+HttpStructuredCacheService::AcceptsLocalFileReferences() const
+{
+ return true;
+}
+
+const ILocalRefPolicy*
+HttpStructuredCacheService::GetLocalRefPolicy() const
+{
+ return m_LocalRefPolicy;
+}
+
void
HttpStructuredCacheService::Flush()
{
diff --git a/src/zenserver/storage/cache/httpstructuredcache.h b/src/zenserver/storage/cache/httpstructuredcache.h
index fc80b449e..f606126d6 100644
--- a/src/zenserver/storage/cache/httpstructuredcache.h
+++ b/src/zenserver/storage/cache/httpstructuredcache.h
@@ -76,11 +76,14 @@ public:
HttpStatusService& StatusService,
UpstreamCache& UpstreamCache,
const DiskWriteBlocker* InDiskWriteBlocker,
- OpenProcessCache& InOpenProcessCache);
+ OpenProcessCache& InOpenProcessCache,
+ const ILocalRefPolicy* InLocalRefPolicy = nullptr);
~HttpStructuredCacheService();
- virtual const char* BaseUri() const override;
- virtual void HandleRequest(HttpServerRequest& Request) override;
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override;
+ virtual bool AcceptsLocalFileReferences() const override;
+ virtual const ILocalRefPolicy* GetLocalRefPolicy() const override;
void Flush();
@@ -125,6 +128,7 @@ private:
const DiskWriteBlocker* m_DiskWriteBlocker = nullptr;
OpenProcessCache& m_OpenProcessCache;
CacheRpcHandler m_RpcHandler;
+ const ILocalRefPolicy* m_LocalRefPolicy = nullptr;
void ReplayRequestRecorder(const CacheRequestContext& Context, cache::IRpcRequestReplayer& Replayer, uint32_t ThreadCount);
diff --git a/src/zenserver/storage/localrefpolicy.cpp b/src/zenserver/storage/localrefpolicy.cpp
new file mode 100644
index 000000000..47ef13b28
--- /dev/null
+++ b/src/zenserver/storage/localrefpolicy.cpp
@@ -0,0 +1,29 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "localrefpolicy.h"
+
+#include <zencore/except_fmt.h>
+#include <zencore/fmtutils.h>
+
+#include <filesystem>
+
+namespace zen {
+
+DataRootLocalRefPolicy::DataRootLocalRefPolicy(const std::filesystem::path& DataRoot)
+: m_CanonicalRoot(std::filesystem::weakly_canonical(DataRoot).string())
+{
+}
+
+void
+DataRootLocalRefPolicy::ValidatePath(const std::filesystem::path& Path) const
+{
+ std::filesystem::path CanonicalFile = std::filesystem::weakly_canonical(Path);
+ std::string FileStr = CanonicalFile.string();
+
+ if (FileStr.size() < m_CanonicalRoot.size() || FileStr.compare(0, m_CanonicalRoot.size(), m_CanonicalRoot) != 0)
+ {
+ throw zen::invalid_argument("local file reference '{}' is outside allowed data root", CanonicalFile);
+ }
+}
+
+} // namespace zen
diff --git a/src/zenserver/storage/localrefpolicy.h b/src/zenserver/storage/localrefpolicy.h
new file mode 100644
index 000000000..3686d1880
--- /dev/null
+++ b/src/zenserver/storage/localrefpolicy.h
@@ -0,0 +1,25 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/localrefpolicy.h>
+
+#include <filesystem>
+#include <string>
+
+namespace zen {
+
+/// Local ref policy that restricts file paths to a canonical data root directory.
+/// Uses weakly_canonical + string prefix comparison to detect path traversal.
+class DataRootLocalRefPolicy : public ILocalRefPolicy
+{
+public:
+ explicit DataRootLocalRefPolicy(const std::filesystem::path& DataRoot);
+
+ void ValidatePath(const std::filesystem::path& Path) const override;
+
+private:
+ std::string m_CanonicalRoot;
+};
+
+} // namespace zen
diff --git a/src/zenserver/storage/objectstore/objectstore.cpp b/src/zenserver/storage/objectstore/objectstore.cpp
index d6516fa1a..1115c1cd6 100644
--- a/src/zenserver/storage/objectstore/objectstore.cpp
+++ b/src/zenserver/storage/objectstore/objectstore.cpp
@@ -637,11 +637,7 @@ HttpObjectStoreService::GetObject(HttpRouterRequest& Request, const std::string_
}
HttpRanges Ranges;
- if (Request.ServerRequest().TryGetRanges(Ranges); Ranges.size() > 1)
- {
- // Only a single range is supported
- return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
- }
+ Request.ServerRequest().TryGetRanges(Ranges);
FileContents File;
{
@@ -665,42 +661,49 @@ HttpObjectStoreService::GetObject(HttpRouterRequest& Request, const std::string_
if (Ranges.empty())
{
- const uint64_t TotalServed = m_TotalBytesServed.fetch_add(FileBuf.Size()) + FileBuf.Size();
-
+ const uint64_t TotalServed = m_TotalBytesServed.fetch_add(FileBuf.GetSize()) + FileBuf.GetSize();
ZEN_LOG_DEBUG(LogObj,
"GET - '{}/{}' ({}) [OK] (Served: {})",
BucketName,
RelativeBucketPath,
- NiceBytes(FileBuf.Size()),
+ NiceBytes(FileBuf.GetSize()),
NiceBytes(TotalServed));
-
- Request.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, FileBuf);
}
else
{
- const auto Range = Ranges[0];
- const uint64_t RangeSize = 1 + (Range.End - Range.Start);
- const uint64_t TotalServed = m_TotalBytesServed.fetch_add(RangeSize) + RangeSize;
-
- ZEN_LOG_DEBUG(LogObj,
- "GET - '{}/{}' (Range: {}-{}) ({}/{}) [OK] (Served: {})",
- BucketName,
- RelativeBucketPath,
- Range.Start,
- Range.End,
- NiceBytes(RangeSize),
- NiceBytes(FileBuf.Size()),
- NiceBytes(TotalServed));
-
- MemoryView RangeView = FileBuf.GetView().Mid(Range.Start, RangeSize);
- if (RangeView.GetSize() != RangeSize)
+ const uint64_t TotalSize = FileBuf.GetSize();
+ uint64_t ServedBytes = 0;
+ for (const HttpRange& Range : Ranges)
{
- return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
+ const uint64_t RangeEnd = (Range.End != ~uint64_t(0)) ? Range.End : TotalSize - 1;
+ if (RangeEnd < TotalSize && Range.Start <= RangeEnd)
+ {
+ ServedBytes += 1 + (RangeEnd - Range.Start);
+ }
+ }
+ if (ServedBytes > 0)
+ {
+ const uint64_t TotalServed = m_TotalBytesServed.fetch_add(ServedBytes) + ServedBytes;
+ ZEN_LOG_DEBUG(LogObj,
+ "GET - '{}/{}' (Ranges: {}) ({}/{}) [OK] (Served: {})",
+ BucketName,
+ RelativeBucketPath,
+ Ranges.size(),
+ NiceBytes(ServedBytes),
+ NiceBytes(TotalSize),
+ NiceBytes(TotalServed));
+ }
+ else
+ {
+ ZEN_LOG_DEBUG(LogObj,
+ "GET - '{}/{}' (Ranges: {}) [416] ({})",
+ BucketName,
+ RelativeBucketPath,
+ Ranges.size(),
+ NiceBytes(TotalSize));
}
-
- IoBuffer RangeBuf = IoBuffer(IoBuffer::Wrap, RangeView.GetData(), RangeView.GetSize());
- Request.ServerRequest().WriteResponse(HttpResponseCode::PartialContent, HttpContentType::kBinary, RangeBuf);
}
+ Request.ServerRequest().WriteResponse(HttpContentType::kBinary, FileBuf, Ranges);
}
void
diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp
index a7c8c66b6..2a6c62195 100644
--- a/src/zenserver/storage/projectstore/httpprojectstore.cpp
+++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp
@@ -18,7 +18,6 @@
#include <zenremotestore/builds/buildstoragecache.h>
#include <zenremotestore/builds/buildstorageutil.h>
#include <zenremotestore/jupiter/jupiterhost.h>
-#include <zenremotestore/operationlogoutput.h>
#include <zenremotestore/projectstore/buildsremoteprojectstore.h>
#include <zenremotestore/projectstore/fileremoteprojectstore.h>
#include <zenremotestore/projectstore/jupiterremoteprojectstore.h>
@@ -279,7 +278,7 @@ namespace {
{
ZEN_MEMSCOPE(GetProjectHttpTag());
- auto Log = [InLog]() { return InLog; };
+ ZEN_SCOPED_LOG(InLog);
using namespace std::literals;
@@ -566,11 +565,9 @@ namespace {
.AllowResume = true,
.RetryCount = 2};
- std::unique_ptr<OperationLogOutput> Output(CreateStandardLogOutput(Log()));
-
try
{
- ResolveResult = ResolveBuildStorage(*Output,
+ ResolveResult = ResolveBuildStorage(Log(),
ClientSettings,
Host,
OverrideHost,
@@ -656,7 +653,8 @@ HttpProjectService::HttpProjectService(CidStore& Store,
JobQueue& InJobQueue,
bool InRestrictContentTypes,
const std::filesystem::path& InOidcTokenExePath,
- bool InAllowExternalOidcTokenExe)
+ bool InAllowExternalOidcTokenExe,
+ const ILocalRefPolicy* InLocalRefPolicy)
: m_Log(logging::Get("project"))
, m_CidStore(Store)
, m_ProjectStore(Projects)
@@ -668,6 +666,7 @@ HttpProjectService::HttpProjectService(CidStore& Store,
, m_RestrictContentTypes(InRestrictContentTypes)
, m_OidcTokenExePath(InOidcTokenExePath)
, m_AllowExternalOidcTokenExe(InAllowExternalOidcTokenExe)
+, m_LocalRefPolicy(InLocalRefPolicy)
{
ZEN_MEMSCOPE(GetProjectHttpTag());
@@ -785,22 +784,22 @@ HttpProjectService::HttpProjectService(CidStore& Store,
HttpVerb::kPost);
m_Router.RegisterRoute(
- "details\\$",
+ "details$",
[this](HttpRouterRequest& Req) { HandleDetailsRequest(Req); },
HttpVerb::kGet);
m_Router.RegisterRoute(
- "details\\$/{project}",
+ "details$/{project}",
[this](HttpRouterRequest& Req) { HandleProjectDetailsRequest(Req); },
HttpVerb::kGet);
m_Router.RegisterRoute(
- "details\\$/{project}/{log}",
+ "details$/{project}/{log}",
[this](HttpRouterRequest& Req) { HandleOplogDetailsRequest(Req); },
HttpVerb::kGet);
m_Router.RegisterRoute(
- "details\\$/{project}/{log}/{chunk}",
+ "details$/{project}/{log}/{chunk}",
[this](HttpRouterRequest& Req) { HandleOplogOpDetailsRequest(Req); },
HttpVerb::kGet);
@@ -820,6 +819,18 @@ HttpProjectService::BaseUri() const
return "/prj/";
}
+bool
+HttpProjectService::AcceptsLocalFileReferences() const
+{
+ return true;
+}
+
+const ILocalRefPolicy*
+HttpProjectService::GetLocalRefPolicy() const
+{
+ return m_LocalRefPolicy;
+}
+
void
HttpProjectService::HandleRequest(HttpServerRequest& Request)
{
@@ -1250,7 +1261,7 @@ HttpProjectService::HandleChunkInfoRequest(HttpRouterRequest& Req)
const Oid Obj = Oid::FromHexString(ChunkId);
- CbObject ResponsePayload = ProjectStore::GetChunkInfo(Log(), *Project, *FoundLog, Obj);
+ CbObject ResponsePayload = ProjectStore::GetChunkInfo(*Project, *FoundLog, Obj);
if (ResponsePayload)
{
m_ProjectStats.ChunkHitCount++;
@@ -1339,7 +1350,7 @@ HttpProjectService::HandleChunkByIdRequest(HttpRouterRequest& Req)
HttpContentType AcceptType = HttpReq.AcceptContentType();
ProjectStore::GetChunkRangeResult Result =
- ProjectStore::GetChunkRange(Log(), *Project, *FoundLog, Obj, Offset, Size, AcceptType, /*OptionalInOutModificationTag*/ nullptr);
+ ProjectStore::GetChunkRange(*Project, *FoundLog, Obj, Offset, Size, AcceptType, /*OptionalInOutModificationTag*/ nullptr);
switch (Result.Error)
{
@@ -1668,7 +1679,8 @@ HttpProjectService::HandleOplogOpNewRequest(HttpRouterRequest& Req)
CbPackage Package;
- if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver))
+ const bool ValidateHashes = false;
+ if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver, ValidateHashes))
{
CbValidateError ValidateResult;
if (CbObject Core = ValidateAndReadCompactBinaryObject(IoBuffer(Payload), ValidateResult);
@@ -2676,6 +2688,7 @@ HttpProjectService::HandleOplogLoadRequest(HttpRouterRequest& Req)
try
{
CbObject ContainerObject = BuildContainer(
+ Log(),
m_CidStore,
*Project,
*Oplog,
@@ -2763,7 +2776,11 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req)
case HttpContentType::kCbPackage:
try
{
- Package = ParsePackageMessage(Payload);
+ ParseFlags PkgFlags = (HttpReq.IsLocalMachineRequest() && AcceptsLocalFileReferences()) ? ParseFlags::kAllowLocalReferences
+ : ParseFlags::kDefault;
+ const ILocalRefPolicy* PkgPolicy =
+ EnumHasAllFlags(PkgFlags, ParseFlags::kAllowLocalReferences) ? GetLocalRefPolicy() : nullptr;
+ Package = ParsePackageMessage(Payload, {}, PkgFlags, PkgPolicy);
Cb = Package.GetObject();
}
catch (const std::invalid_argument& ex)
@@ -2872,6 +2889,7 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req)
try
{
LoadOplog(LoadOplogContext{
+ .Log = Log(),
.ChunkStore = m_CidStore,
.RemoteStore = *RemoteStoreResult->Store,
.OptionalCache = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->Cache.get() : nullptr,
@@ -2997,7 +3015,8 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req)
try
{
- SaveOplog(m_CidStore,
+ SaveOplog(Log(),
+ m_CidStore,
*ActualRemoteStore,
*Project,
*Oplog,
diff --git a/src/zenserver/storage/projectstore/httpprojectstore.h b/src/zenserver/storage/projectstore/httpprojectstore.h
index e3ed02f26..8aa345fa7 100644
--- a/src/zenserver/storage/projectstore/httpprojectstore.h
+++ b/src/zenserver/storage/projectstore/httpprojectstore.h
@@ -47,11 +47,14 @@ public:
JobQueue& InJobQueue,
bool InRestrictContentTypes,
const std::filesystem::path& InOidcTokenExePath,
- bool AllowExternalOidcTokenExe);
+ bool AllowExternalOidcTokenExe,
+ const ILocalRefPolicy* InLocalRefPolicy = nullptr);
~HttpProjectService();
- virtual const char* BaseUri() const override;
- virtual void HandleRequest(HttpServerRequest& Request) override;
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override;
+ virtual bool AcceptsLocalFileReferences() const override;
+ virtual const ILocalRefPolicy* GetLocalRefPolicy() const override;
virtual void HandleStatusRequest(HttpServerRequest& Request) override;
virtual void HandleStatsRequest(HttpServerRequest& Request) override;
@@ -117,6 +120,7 @@ private:
bool m_RestrictContentTypes;
std::filesystem::path m_OidcTokenExePath;
bool m_AllowExternalOidcTokenExe;
+ const ILocalRefPolicy* m_LocalRefPolicy;
Ref<TransferThreadWorkers> GetThreadWorkers(bool BoostWorkers, bool SingleThreaded);
};
diff --git a/src/zenserver/storage/storageconfig.cpp b/src/zenserver/storage/storageconfig.cpp
index 0dbb45164..bb4f053e4 100644
--- a/src/zenserver/storage/storageconfig.cpp
+++ b/src/zenserver/storage/storageconfig.cpp
@@ -57,6 +57,12 @@ ZenStorageServerConfigurator::ValidateOptions()
ZEN_WARN("'--gc-v2=false' is deprecated, reverting to '--gc-v2=true'");
ServerOptions.GcConfig.UseGCV2 = true;
}
+ if (ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent > 100)
+ {
+ throw OptionParseException(fmt::format("'--buildstore-disksizelimit-percent' ('{}') is invalid, must be between 1 and 100.",
+ ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent),
+ {});
+ }
}
class ZenStructuredCacheBucketsConfigOption : public LuaConfig::OptionValue
@@ -382,6 +388,9 @@ ZenStorageServerConfigurator::AddConfigOptions(LuaConfig::Options& LuaOptions)
////// buildsstore
LuaOptions.AddOption("server.buildstore.enabled"sv, ServerOptions.BuildStoreConfig.Enabled, "buildstore-enabled"sv);
LuaOptions.AddOption("server.buildstore.disksizelimit"sv, ServerOptions.BuildStoreConfig.MaxDiskSpaceLimit, "buildstore-disksizelimit");
+ LuaOptions.AddOption("server.buildstore.disksizelimitpercent"sv,
+ ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent,
+ "buildstore-disksizelimit-percent");
////// cache
LuaOptions.AddOption("cache.enable"sv, ServerOptions.StructuredCacheConfig.Enabled);
@@ -477,7 +486,7 @@ ZenStorageServerConfigurator::AddConfigOptions(LuaConfig::Options& LuaOptions)
ServerOptions.GcConfig.CompactBlockUsageThresholdPercent,
"gc-compactblock-threshold"sv);
LuaOptions.AddOption("gc.verbose"sv, ServerOptions.GcConfig.Verbose, "gc-verbose"sv);
- LuaOptions.AddOption("gc.single-threaded"sv, ServerOptions.GcConfig.SingleThreaded, "gc-single-threaded"sv);
+ LuaOptions.AddOption("gc.singlethreaded"sv, ServerOptions.GcConfig.SingleThreaded, "gc-single-threaded"sv);
LuaOptions.AddOption("gc.cache.attachment.store"sv, ServerOptions.GcConfig.StoreCacheAttachmentMetaData, "gc-cache-attachment-store");
LuaOptions.AddOption("gc.projectstore.attachment.store"sv,
ServerOptions.GcConfig.StoreProjectAttachmentMetaData,
@@ -1035,6 +1044,13 @@ ZenStorageServerCmdLineOptions::AddBuildStoreOptions(cxxopts::Options& options,
"Max number of bytes before build store entries get evicted. Default set to 1099511627776 (1TB week)",
cxxopts::value<uint64_t>(ServerOptions.BuildStoreConfig.MaxDiskSpaceLimit)->default_value("1099511627776"),
"");
+ options.add_option("buildstore",
+ "",
+ "buildstore-disksizelimit-percent",
+ "Max percentage (1-100) of total drive capacity (of --data-dir drive) before build store entries get evicted. "
+ "0 (default) disables this limit. When combined with --buildstore-disksizelimit, the lower value wins.",
+ cxxopts::value<uint32_t>(ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent)->default_value("0"),
+ "");
}
void
diff --git a/src/zenserver/storage/storageconfig.h b/src/zenserver/storage/storageconfig.h
index 18af4f096..fec8fd70b 100644
--- a/src/zenserver/storage/storageconfig.h
+++ b/src/zenserver/storage/storageconfig.h
@@ -135,8 +135,9 @@ struct ZenProjectStoreConfig
struct ZenBuildStoreConfig
{
- bool Enabled = false;
- uint64_t MaxDiskSpaceLimit = 1u * 1024u * 1024u * 1024u * 1024u; // 1TB
+ bool Enabled = false;
+ uint64_t MaxDiskSpaceLimit = 1u * 1024u * 1024u * 1024u * 1024u; // 1TB
+ uint32_t MaxDiskSpaceLimitPercent = 0;
};
struct ZenWorkspacesConfig
diff --git a/src/zenserver/storage/upstream/upstreamcache.cpp b/src/zenserver/storage/upstream/upstreamcache.cpp
index b26c57414..a516c452c 100644
--- a/src/zenserver/storage/upstream/upstreamcache.cpp
+++ b/src/zenserver/storage/upstream/upstreamcache.cpp
@@ -772,7 +772,7 @@ namespace detail {
UpstreamEndpointInfo m_Info;
UpstreamStatus m_Status;
UpstreamEndpointStats m_Stats;
- RefPtr<JupiterClient> m_Client;
+ Ref<JupiterClient> m_Client;
const bool m_AllowRedirect = false;
};
@@ -1446,7 +1446,7 @@ namespace detail {
// Make sure we safely bump the refcount inside a scope lock
RwLock::SharedLockScope _(m_ClientLock);
ZEN_ASSERT(m_Client);
- Ref<ZenStructuredCacheClient> ClientRef(m_Client);
+ Ref<ZenStructuredCacheClient> ClientRef(m_Client.Get());
_.ReleaseNow();
return ClientRef;
}
@@ -1485,15 +1485,15 @@ namespace detail {
LoggerRef Log() { return m_Log; }
- LoggerRef m_Log;
- UpstreamEndpointInfo m_Info;
- UpstreamStatus m_Status;
- UpstreamEndpointStats m_Stats;
- std::vector<ZenEndpoint> m_Endpoints;
- std::chrono::milliseconds m_ConnectTimeout;
- std::chrono::milliseconds m_Timeout;
- RwLock m_ClientLock;
- RefPtr<ZenStructuredCacheClient> m_Client;
+ LoggerRef m_Log;
+ UpstreamEndpointInfo m_Info;
+ UpstreamStatus m_Status;
+ UpstreamEndpointStats m_Stats;
+ std::vector<ZenEndpoint> m_Endpoints;
+ std::chrono::milliseconds m_ConnectTimeout;
+ std::chrono::milliseconds m_Timeout;
+ RwLock m_ClientLock;
+ Ref<ZenStructuredCacheClient> m_Client;
};
} // namespace detail
diff --git a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp
index bc0a8f4ac..7b52f2832 100644
--- a/src/zenserver/storage/zenstorageserver.cpp
+++ b/src/zenserver/storage/zenstorageserver.cpp
@@ -37,8 +37,6 @@
#include <zenutil/sessionsclient.h>
#include <zenutil/workerpools.h>
#include <zenutil/zenserverprocess.h>
-#include "sessions/inprocsessionlogsink.h"
-#include "sessions/sessions.h"
#if ZEN_PLATFORM_WINDOWS
# include <zencore/windows.h>
@@ -165,11 +163,6 @@ ZenStorageServer::RegisterServices()
m_Http->RegisterService(*m_HttpWorkspacesService);
}
- if (m_HttpSessionsService)
- {
- m_Http->RegisterService(*m_HttpSessionsService);
- }
-
m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatsService, m_StatusService);
if (m_FrontendService)
@@ -223,12 +216,13 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions
ZEN_INFO("instantiating project service");
+ m_LocalRefPolicy = std::make_unique<DataRootLocalRefPolicy>(m_DataRoot);
m_JobQueue = MakeJobQueue(8, "bgjobs");
m_OpenProcessCache = std::make_unique<OpenProcessCache>();
m_ProjectStore = new ProjectStore(*m_CidStore, m_DataRoot / "projects", m_GcManager, ProjectStore::Configuration{});
m_HttpProjectService.reset(new HttpProjectService{*m_CidStore,
- m_ProjectStore,
+ m_ProjectStore.Get(),
m_StatusService,
m_StatsService,
*m_AuthMgr,
@@ -236,7 +230,8 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions
*m_JobQueue,
ServerOptions.RestrictContentTypes,
ServerOptions.OidcTokenExecutable,
- ServerOptions.AllowExternalOidcTokenExe});
+ ServerOptions.AllowExternalOidcTokenExe,
+ m_LocalRefPolicy.get()});
if (ServerOptions.WorksSpacesConfig.Enabled)
{
@@ -251,16 +246,6 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions
*m_Workspaces));
}
- {
- m_SessionsService = std::make_unique<SessionsService>();
- m_HttpSessionsService = std::make_unique<HttpSessionsService>(m_StatusService, m_StatsService, *m_SessionsService, m_IoContext);
- m_HttpSessionsService->SetSelfSessionId(GetSessionId());
-
- m_InProcSessionLogSink = logging::SinkPtr(new InProcSessionLogSink(*m_SessionsService));
- m_InProcSessionLogSink->SetLevel(logging::Info);
- GetDefaultBroadcastSink()->AddSink(m_InProcSessionLogSink);
- }
-
if (!ServerOptions.SessionsTargetUrl.empty())
{
m_SessionsClient = std::make_unique<SessionsServiceClient>(SessionsServiceClient::Options{
@@ -281,7 +266,31 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions
BuildStoreConfig BuildsCfg;
BuildsCfg.RootDirectory = m_DataRoot / "builds";
BuildsCfg.MaxDiskSpaceLimit = ServerOptions.BuildStoreConfig.MaxDiskSpaceLimit;
- m_BuildStore = std::make_unique<BuildStore>(std::move(BuildsCfg), m_GcManager, *m_BuildCidStore);
+
+ if (ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent > 0)
+ {
+ DiskSpace Space;
+ if (DiskSpaceInfo(m_DataRoot, Space) && Space.Total > 0)
+ {
+ uint64_t PercentLimit = Space.Total * ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent / 100;
+ BuildsCfg.MaxDiskSpaceLimit = ServerOptions.BuildStoreConfig.MaxDiskSpaceLimit > 0
+ ? std::min(ServerOptions.BuildStoreConfig.MaxDiskSpaceLimit, PercentLimit)
+ : PercentLimit;
+ ZEN_INFO("buildstore disk limit: {}% of {} = {} (effective limit: {})",
+ ServerOptions.BuildStoreConfig.MaxDiskSpaceLimitPercent,
+ NiceBytes(Space.Total),
+ NiceBytes(PercentLimit),
+ NiceBytes(BuildsCfg.MaxDiskSpaceLimit));
+ }
+ else
+ {
+ ZEN_WARN("buildstore-disksizelimit-percent: failed to query disk space for {}, using absolute limit {}",
+ m_DataRoot.string(),
+ NiceBytes(BuildsCfg.MaxDiskSpaceLimit));
+ }
+ }
+
+ m_BuildStore = std::make_unique<BuildStore>(std::move(BuildsCfg), m_GcManager, *m_BuildCidStore);
}
if (ServerOptions.StructuredCacheConfig.Enabled)
@@ -323,13 +332,13 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions
ZEN_OTEL_SPAN("InitializeComputeService");
m_HttpComputeService =
- std::make_unique<compute::HttpComputeService>(*m_CidStore, m_StatsService, ServerOptions.DataDir / "functions");
+ std::make_unique<compute::HttpComputeService>(*m_CidStore, *m_CidStore, m_StatsService, ServerOptions.DataDir / "functions");
}
#endif
#if ZEN_WITH_VFS
m_VfsServiceImpl = std::make_unique<VfsServiceImpl>();
- m_VfsServiceImpl->AddService(Ref<ProjectStore>(m_ProjectStore));
+ m_VfsServiceImpl->AddService(Ref<ProjectStore>(m_ProjectStore.Get()));
m_VfsServiceImpl->AddService(Ref<ZenCacheStore>(m_CacheStore));
m_VfsService = std::make_unique<VfsService>(m_StatusService, m_VfsServiceImpl.get());
@@ -713,7 +722,8 @@ ZenStorageServer::InitializeStructuredCache(const ZenStorageServerConfig& Server
m_StatusService,
*m_UpstreamCache,
m_GcManager.GetDiskWriteBlocker(),
- *m_OpenProcessCache);
+ *m_OpenProcessCache,
+ m_LocalRefPolicy.get());
m_StatsReporter.AddProvider(m_CacheStore.Get());
m_StatsReporter.AddProvider(m_CidStore.get());
@@ -838,7 +848,7 @@ ZenStorageServer::Run()
OnReady();
- m_SessionsService->RegisterSession(GetSessionId(), "zenserver", GetServerMode(), Oid::Zero, {});
+ StartSelfSession("zenserver");
if (m_SessionsClient)
{
@@ -888,11 +898,6 @@ ZenStorageServer::Cleanup()
m_Http->Close();
}
- if (m_InProcSessionLogSink)
- {
- GetDefaultBroadcastSink()->RemoveSink(m_InProcSessionLogSink);
- m_InProcSessionLogSink = {};
- }
if (m_SessionLogSink)
{
GetDefaultBroadcastSink()->RemoveSink(m_SessionLogSink);
@@ -904,11 +909,6 @@ ZenStorageServer::Cleanup()
m_SessionsClient.reset();
}
- if (m_SessionsService)
- {
- m_SessionsService->RemoveSession(GetSessionId());
- }
-
ShutdownServices();
if (m_JobQueue)
@@ -940,8 +940,6 @@ ZenStorageServer::Cleanup()
m_UpstreamCache.reset();
m_CacheStore = {};
- m_HttpSessionsService.reset();
- m_SessionsService.reset();
m_HttpWorkspacesService.reset();
m_Workspaces.reset();
m_HttpProjectService.reset();
diff --git a/src/zenserver/storage/zenstorageserver.h b/src/zenserver/storage/zenstorageserver.h
index fad22ad54..9fa46ba9b 100644
--- a/src/zenserver/storage/zenstorageserver.h
+++ b/src/zenserver/storage/zenstorageserver.h
@@ -11,6 +11,7 @@
#include <zenstore/cache/structuredcachestore.h>
#include <zenstore/gc.h>
#include <zenstore/projectstore.h>
+#include "localrefpolicy.h"
#include "admin/admin.h"
#include "buildstore/httpbuildstore.h"
@@ -19,7 +20,6 @@
#include "frontend/frontend.h"
#include "objectstore/objectstore.h"
#include "projectstore/httpprojectstore.h"
-#include "sessions/httpsessions.h"
#include "stats/statsreporter.h"
#include "upstream/upstream.h"
#include "vfs/vfsservice.h"
@@ -65,27 +65,26 @@ private:
void InitializeServices(const ZenStorageServerConfig& ServerOptions);
void RegisterServices();
- std::unique_ptr<JobQueue> m_JobQueue;
- GcManager m_GcManager;
- GcScheduler m_GcScheduler{m_GcManager};
- std::unique_ptr<CidStore> m_CidStore;
- Ref<ZenCacheStore> m_CacheStore;
- std::unique_ptr<OpenProcessCache> m_OpenProcessCache;
- HttpTestService m_TestService;
- std::unique_ptr<CidStore> m_BuildCidStore;
- std::unique_ptr<BuildStore> m_BuildStore;
+ std::unique_ptr<DataRootLocalRefPolicy> m_LocalRefPolicy;
+ std::unique_ptr<JobQueue> m_JobQueue;
+ GcManager m_GcManager;
+ GcScheduler m_GcScheduler{m_GcManager};
+ std::unique_ptr<CidStore> m_CidStore;
+ Ref<ZenCacheStore> m_CacheStore;
+ std::unique_ptr<OpenProcessCache> m_OpenProcessCache;
+ HttpTestService m_TestService;
+ std::unique_ptr<CidStore> m_BuildCidStore;
+ std::unique_ptr<BuildStore> m_BuildStore;
#if ZEN_WITH_TESTS
HttpTestingService m_TestingService;
#endif
- RefPtr<ProjectStore> m_ProjectStore;
+ Ref<ProjectStore> m_ProjectStore;
std::unique_ptr<VfsServiceImpl> m_VfsServiceImpl;
std::unique_ptr<HttpProjectService> m_HttpProjectService;
std::unique_ptr<Workspaces> m_Workspaces;
std::unique_ptr<HttpWorkspacesService> m_HttpWorkspacesService;
- std::unique_ptr<SessionsService> m_SessionsService;
- std::unique_ptr<HttpSessionsService> m_HttpSessionsService;
std::unique_ptr<UpstreamCache> m_UpstreamCache;
std::unique_ptr<HttpUpstreamService> m_UpstreamService;
std::unique_ptr<HttpStructuredCacheService> m_StructuredCacheService;
@@ -98,7 +97,6 @@ private:
std::unique_ptr<SessionsServiceClient> m_SessionsClient;
logging::SinkPtr m_SessionLogSink;
- logging::SinkPtr m_InProcSessionLogSink;
asio::steady_timer m_SessionAnnounceTimer{m_IoContext};
void EnqueueSessionAnnounceTimer();
diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua
index c2c81e7aa..b609d1050 100644
--- a/src/zenserver/xmake.lua
+++ b/src/zenserver/xmake.lua
@@ -32,7 +32,7 @@ target("zenserver")
add_deps("protozero", "asio", "cxxopts")
add_deps("sol2")
- add_packages("http_parser")
+ add_packages("llhttp")
add_packages("json11")
add_packages("zlib")
add_packages("lua")
diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp
index 6aa02eb87..e68e46bd6 100644
--- a/src/zenserver/zenserver.cpp
+++ b/src/zenserver/zenserver.cpp
@@ -13,6 +13,7 @@
#include <zencore/iobuffer.h>
#include <zencore/jobqueue.h>
#include <zencore/logging.h>
+#include <zencore/logging/broadcastsink.h>
#include <zencore/memory/fmalloc.h>
#include <zencore/scopeguard.h>
#include <zencore/sentryintegration.h>
@@ -28,6 +29,7 @@
#include <zenhttp/security/passwordsecurityfilter.h>
#include <zentelemetry/otlptrace.h>
#include <zenutil/authutils.h>
+#include <zenutil/logging.h>
#include <zenutil/service.h>
#include <zenutil/workerpools.h>
#include <zenutil/zenserverprocess.h>
@@ -64,6 +66,9 @@ ZEN_THIRD_PARTY_INCLUDES_END
#include "config/config.h"
#include "diag/logging.h"
+#include "sessions/httpsessions.h"
+#include "sessions/inprocsessionlogsink.h"
+#include "sessions/sessions.h"
#include <zencore/memory/llm.h>
@@ -225,6 +230,8 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState::
LogSettingsSummary(ServerOptions);
+ InitializeSessions();
+
return EffectiveBasePort;
}
@@ -233,6 +240,11 @@ ZenServerBase::Finalize()
{
m_StatsService.RegisterHandler("http", *m_Http);
+ if (m_HttpSessionsService)
+ {
+ m_Http->RegisterService(*m_HttpSessionsService);
+ }
+
m_Http->SetDefaultRedirect("/dashboard/");
// Register health service last so if we return "OK" for health it means all services have been properly initialized
@@ -243,11 +255,49 @@ ZenServerBase::Finalize()
void
ZenServerBase::ShutdownServices()
{
- m_StatsService.UnregisterHandler("http", *m_Http);
+ if (m_InProcSessionLogSink)
+ {
+ GetDefaultBroadcastSink()->RemoveSink(m_InProcSessionLogSink);
+ m_InProcSessionLogSink = {};
+ }
+
+ if (m_SessionsService)
+ {
+ m_SessionsService->RemoveSession(GetSessionId());
+ }
+
+ m_HttpSessionsService.reset();
+ m_SessionsService.reset();
+
+ if (m_Http)
+ {
+ m_StatsService.UnregisterHandler("http", *m_Http);
+ }
m_StatsService.Shutdown();
}
void
+ZenServerBase::InitializeSessions()
+{
+ m_SessionsService = std::make_unique<SessionsService>();
+ m_HttpSessionsService = std::make_unique<HttpSessionsService>(m_StatusService, m_StatsService, *m_SessionsService, m_IoContext);
+ m_HttpSessionsService->SetSelfSessionId(GetSessionId());
+
+ m_InProcSessionLogSink = logging::SinkPtr(new InProcSessionLogSink(*m_SessionsService));
+ m_InProcSessionLogSink->SetLevel(logging::Info);
+ GetDefaultBroadcastSink()->AddSink(m_InProcSessionLogSink);
+}
+
+void
+ZenServerBase::StartSelfSession(std::string_view AppName)
+{
+ if (m_SessionsService)
+ {
+ m_SessionsService->RegisterSession(GetSessionId(), std::string(AppName), GetServerMode(), Oid::Zero, {});
+ }
+}
+
+void
ZenServerBase::GetBuildOptions(StringBuilderBase& OutOptions, char Separator) const
{
ZEN_MEMSCOPE(GetZenserverTag());
@@ -272,6 +322,8 @@ ZenServerBase::GetBuildOptions(StringBuilderBase& OutOptions, char Separator) co
OutOptions << Separator;
OutOptions << "ZEN_WITH_MEMTRACK=" << (ZEN_WITH_MEMTRACK ? "1" : "0");
OutOptions << Separator;
+ OutOptions << "ZEN_WITH_COMPUTE_SERVICES=" << (ZEN_WITH_COMPUTE_SERVICES ? "1" : "0");
+ OutOptions << Separator;
OutOptions << "ZEN_WITH_TRACE=" << (ZEN_WITH_TRACE ? "1" : "0");
}
@@ -697,7 +749,7 @@ ZenServerMain::Run()
// The entry's process failed to pick up our sponsor request after
// multiple attempts. Before reclaiming the entry, verify that the
// PID does not still belong to a zenserver process. If it does, the
- // server is alive but unresponsive – fall back to the original error
+ // server is alive but unresponsive - fall back to the original error
// path. If the PID is gone or belongs to a different executable the
// entry is genuinely stale and safe to reclaim.
const int StalePid = Entry->Pid.load();
@@ -715,7 +767,7 @@ ZenServerMain::Run()
}
ZEN_CONSOLE_WARN(
"Failed to add sponsor to process on port {} (pid {}); "
- "pid belongs to '{}' – assuming stale entry and reclaiming",
+ "pid belongs to '{}' - assuming stale entry and reclaiming",
m_ServerOptions.BasePort,
StalePid,
ExeEc ? "<unknown>" : PidExePath.filename().string());
diff --git a/src/zenserver/zenserver.h b/src/zenserver/zenserver.h
index f5286e9ee..995ff054f 100644
--- a/src/zenserver/zenserver.h
+++ b/src/zenserver/zenserver.h
@@ -3,6 +3,7 @@
#pragma once
#include <zencore/basicfile.h>
+#include <zencore/logging/sink.h>
#include <zencore/system.h>
#include <zenhttp/httpserver.h>
#include <zenhttp/httpstats.h>
@@ -27,6 +28,8 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
+class HttpSessionsService;
+class SessionsService;
struct FLLMTag;
extern const FLLMTag& GetZenserverTag();
@@ -57,6 +60,7 @@ protected:
int Initialize(const ZenServerConfig& ServerOptions, ZenServerState::ZenServerEntry* ServerEntry);
void Finalize();
void ShutdownServices();
+ void StartSelfSession(std::string_view AppName);
void GetBuildOptions(StringBuilderBase& OutOptions, char Separator = ',') const;
static std::vector<std::pair<std::string_view, std::string>> BuildSettingsList(const ZenServerConfig& ServerConfig);
void LogSettingsSummary(const ZenServerConfig& ServerConfig);
@@ -104,6 +108,11 @@ protected:
HttpStatusService m_StatusService;
SystemMetricsTracker m_MetricsTracker;
+ // Sessions (shared by all derived servers)
+ std::unique_ptr<SessionsService> m_SessionsService;
+ std::unique_ptr<HttpSessionsService> m_HttpSessionsService;
+ logging::SinkPtr m_InProcSessionLogSink;
+
// Stats reporting
StatsReporter m_StatsReporter;
@@ -137,6 +146,7 @@ protected:
virtual void HandleStatusRequest(HttpServerRequest& Request) override;
private:
+ void InitializeSessions();
void InitializeSecuritySettings(const ZenServerConfig& ServerOptions);
};
class ZenServerMain
diff --git a/src/zenstore/buildstore/buildstore.cpp b/src/zenstore/buildstore/buildstore.cpp
index dff1c3c61..a08741d31 100644
--- a/src/zenstore/buildstore/buildstore.cpp
+++ b/src/zenstore/buildstore/buildstore.cpp
@@ -1052,7 +1052,7 @@ public:
{
ZEN_TRACE_CPU("Builds::PreCache");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -1107,7 +1107,7 @@ public:
ZEN_TRACE_CPU("Builds::GetUnusedReferences");
ZEN_MEMSCOPE(GetBuildstoreTag());
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
size_t InitialCount = IoCids.size();
size_t UsedCount = InitialCount;
@@ -1152,7 +1152,7 @@ BuildStore::RemoveExpiredData(GcCtx& Ctx, GcStats& Stats)
ZEN_TRACE_CPU("Builds::RemoveExpiredData");
ZEN_MEMSCOPE(GetBuildstoreTag());
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
diff --git a/src/zenstore/cache/cachedisklayer.cpp b/src/zenstore/cache/cachedisklayer.cpp
index 4640309d9..45a4b6456 100644
--- a/src/zenstore/cache/cachedisklayer.cpp
+++ b/src/zenstore/cache/cachedisklayer.cpp
@@ -3083,7 +3083,7 @@ public:
{
ZEN_TRACE_CPU("Z$::Bucket::CompactStore");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -3338,7 +3338,7 @@ ZenCacheDiskLayer::CacheBucket::RemoveExpiredData(GcCtx& Ctx, GcStats& Stats)
{
ZEN_TRACE_CPU("Z$::Bucket::RemoveExpiredData");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
size_t TotalEntries = 0;
@@ -3502,7 +3502,7 @@ ZenCacheDiskLayer::CacheBucket::GetReferences(const LoggerRef& Logger,
{
ZEN_TRACE_CPU("Z$::Bucket::GetReferencesLocked");
- auto Log = [&Logger]() { return Logger; };
+ ZEN_SCOPED_LOG(Logger);
auto GetAttachments = [&](const IoHash& RawHash, MemoryView Data) -> bool {
if (CbValidateError Error = ValidateCompactBinary(Data, CbValidateMode::Default); Error == CbValidateError::None)
@@ -3718,7 +3718,7 @@ public:
{
ZEN_TRACE_CPU("Z$::Bucket::PreCache");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -3753,7 +3753,7 @@ public:
{
ZEN_TRACE_CPU("Z$::Bucket::UpdateLockedState");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -3784,7 +3784,7 @@ public:
{
ZEN_TRACE_CPU("Z$::Bucket::GetUnusedReferences");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
const size_t InitialCount = IoCids.size();
size_t UsedCount = InitialCount;
@@ -3818,7 +3818,7 @@ ZenCacheDiskLayer::CacheBucket::CreateReferenceCheckers(GcCtx& Ctx)
{
ZEN_TRACE_CPU("Z$::Bucket::CreateReferenceCheckers");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
diff --git a/src/zenstore/cache/structuredcachestore.cpp b/src/zenstore/cache/structuredcachestore.cpp
index cff0e9a35..97b793083 100644
--- a/src/zenstore/cache/structuredcachestore.cpp
+++ b/src/zenstore/cache/structuredcachestore.cpp
@@ -468,7 +468,7 @@ ZenCacheStore::LogWorker()
LoggerRef ZCacheLog(logging::Get("z$"));
- auto Log = [&ZCacheLog]() -> LoggerRef { return ZCacheLog; };
+ ZEN_SCOPED_LOG(ZCacheLog);
std::vector<AccessLogItem> Items;
while (true)
@@ -1086,11 +1086,9 @@ ZenCacheStore::GetBucketInfo(std::string_view NamespaceName, std::string_view Bu
std::vector<RwLock::SharedLockScope>
ZenCacheStore::LockState(GcCtx& Ctx)
{
+ ZEN_UNUSED(Ctx);
ZEN_TRACE_CPU("CacheStore::LockState");
- auto Log = [&Ctx]() { return Ctx.Logger; };
- ZEN_UNUSED(Log);
-
std::vector<RwLock::SharedLockScope> Locks;
Locks.emplace_back(RwLock::SharedLockScope(m_NamespacesLock));
for (auto& NamespaceIt : m_Namespaces)
@@ -1211,7 +1209,7 @@ public:
{
ZEN_TRACE_CPU("Z$::UpdateLockedState");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
@@ -1276,7 +1274,7 @@ public:
{
ZEN_TRACE_CPU("Z$::GetUnusedReferences");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
const size_t InitialCount = IoCids.size();
size_t UsedCount = InitialCount;
@@ -1309,7 +1307,7 @@ ZenCacheStore::CreateReferenceCheckers(GcCtx& Ctx)
{
ZEN_TRACE_CPU("CacheStore::CreateReferenceCheckers");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
diff --git a/src/zenstore/cidstore.cpp b/src/zenstore/cidstore.cpp
index b20d8f565..ac8a75a58 100644
--- a/src/zenstore/cidstore.cpp
+++ b/src/zenstore/cidstore.cpp
@@ -188,12 +188,24 @@ CidStore::Initialize(const CidStoreConfiguration& Config)
}
CidStore::InsertResult
+CidStore::AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash)
+{
+ return m_Impl->AddChunk(ChunkData, RawHash, InsertMode::kMayBeMovedInPlace);
+}
+
+CidStore::InsertResult
CidStore::AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash, InsertMode Mode)
{
return m_Impl->AddChunk(ChunkData, RawHash, Mode);
}
std::vector<CidStore::InsertResult>
+CidStore::AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes)
+{
+ return m_Impl->AddChunks(ChunkDatas, RawHashes, InsertMode::kMayBeMovedInPlace);
+}
+
+std::vector<CidStore::InsertResult>
CidStore::AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes, InsertMode Mode)
{
return m_Impl->AddChunks(ChunkDatas, RawHashes, Mode);
diff --git a/src/zenstore/compactcas.cpp b/src/zenstore/compactcas.cpp
index 43dc389e2..6f1e1d701 100644
--- a/src/zenstore/compactcas.cpp
+++ b/src/zenstore/compactcas.cpp
@@ -698,7 +698,7 @@ public:
ZEN_TRACE_CPU("CasContainer::CompactStore");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -875,7 +875,7 @@ public:
ZEN_MEMSCOPE(GetCasContainerTag());
ZEN_TRACE_CPU("CasContainer::RemoveUnreferencedData");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -958,7 +958,7 @@ CasContainerStrategy::CreateReferencePruner(GcCtx& Ctx, GcReferenceStoreStats&)
ZEN_MEMSCOPE(GetCasContainerTag());
ZEN_TRACE_CPU("CasContainer::CreateReferencePruner");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -1391,7 +1391,7 @@ TEST_CASE("compactcas.compact.gc")
{
ScopedTemporaryDirectory TempDir;
- const int kIterationCount = 1000;
+ const int kIterationCount = 200;
std::vector<IoHash> Keys(kIterationCount);
@@ -1504,7 +1504,7 @@ TEST_CASE("compactcas.threadedinsert")
ScopedTemporaryDirectory TempDir;
const uint64_t kChunkSize = 1048;
- const int32_t kChunkCount = 2048;
+ const int32_t kChunkCount = 512;
uint64_t ExpectedSize = 0;
tsl::robin_map<IoHash, IoBuffer, IoHash::Hasher> Chunks;
@@ -1803,7 +1803,7 @@ TEST_CASE("compactcas.restart")
}
const uint64_t kChunkSize = 1048 + 395;
- const size_t kChunkCount = 7167;
+ const size_t kChunkCount = 2000;
std::vector<IoHash> Hashes;
Hashes.reserve(kChunkCount);
@@ -1984,9 +1984,8 @@ TEST_CASE("compactcas.iteratechunks")
WorkerThreadPool ThreadPool(Max(GetHardwareConcurrency() - 1u, 2u), "put");
const uint64_t kChunkSize = 1048 + 395;
- const size_t kChunkCount = 63840;
+ const size_t kChunkCount = 10000;
- for (uint32_t N = 0; N < 2; N++)
{
GcManager Gc;
CasContainerStrategy Cas(Gc);
@@ -2017,7 +2016,7 @@ TEST_CASE("compactcas.iteratechunks")
size_t BatchCount = Min<size_t>(kChunkCount - Offset, 512u);
WorkLatch.AddCount(1);
ThreadPool.ScheduleWork(
- [N, &WorkLatch, &InsertLock, &ChunkHashesLookup, &ExpectedSize, &Hashes, &Cas, Offset, BatchCount]() {
+ [&WorkLatch, &InsertLock, &ChunkHashesLookup, &ExpectedSize, &Hashes, &Cas, Offset, BatchCount]() {
auto _ = MakeGuard([&WorkLatch]() { WorkLatch.CountDown(); });
std::vector<IoBuffer> BatchBlobs;
@@ -2028,7 +2027,7 @@ TEST_CASE("compactcas.iteratechunks")
while (BatchBlobs.size() < BatchCount)
{
IoBuffer Chunk = CreateRandomBlob(
- N + kChunkSize + ((BatchHashes.size() % 100) + (BatchHashes.size() % 7) * 315u + Offset % 377));
+ kChunkSize + ((BatchHashes.size() % 100) + (BatchHashes.size() % 7) * 315u + Offset % 377));
IoHash Hash = IoHash::HashBuffer(Chunk);
{
RwLock::ExclusiveLockScope __(InsertLock);
diff --git a/src/zenstore/filecas.cpp b/src/zenstore/filecas.cpp
index 0088afe6e..3a7a72ee3 100644
--- a/src/zenstore/filecas.cpp
+++ b/src/zenstore/filecas.cpp
@@ -1231,7 +1231,7 @@ public:
ZEN_MEMSCOPE(GetFileCasTag());
ZEN_TRACE_CPU("FileCas::CompactStore");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -1375,7 +1375,7 @@ public:
ZEN_MEMSCOPE(GetFileCasTag());
ZEN_TRACE_CPU("FileCas::RemoveUnreferencedData");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -1458,7 +1458,7 @@ FileCasStrategy::CreateReferencePruner(GcCtx& Ctx, GcReferenceStoreStats&)
ZEN_TRACE_CPU("FileCas::CreateReferencePruner");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
diff --git a/src/zenstore/gc.cpp b/src/zenstore/gc.cpp
index f3edf804d..928fc3f08 100644
--- a/src/zenstore/gc.cpp
+++ b/src/zenstore/gc.cpp
@@ -546,7 +546,7 @@ FilterReferences(GcCtx& Ctx, std::string_view Context, std::vector<IoHash>& InOu
return false;
}
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
const bool Filter = Ctx.Settings.AttachmentRangeMax != IoHash::Max || Ctx.Settings.AttachmentRangeMin != IoHash::Zero;
@@ -2063,6 +2063,14 @@ GcScheduler::GetState() const
{
{
std::unique_lock Lock(m_GcMutex);
+
+ if (m_TriggerGcParams || m_TriggerScrubParams)
+ {
+ // If a trigger is pending, treat it as running
+ Result.Status = GcSchedulerStatus::kRunning;
+ return Result;
+ }
+
Result.LastFullGcTime = m_LastGcTime;
Result.LastFullGCDiff = m_LastFullGCDiff;
Result.LastFullGcDuration = m_LastFullGcDuration;
diff --git a/src/zenstore/include/zenstore/cache/cachepolicy.h b/src/zenstore/include/zenstore/cache/cachepolicy.h
index 7773cd3d1..4a062a0c2 100644
--- a/src/zenstore/include/zenstore/cache/cachepolicy.h
+++ b/src/zenstore/include/zenstore/cache/cachepolicy.h
@@ -163,9 +163,9 @@ private:
friend class CacheRecordPolicyBuilder;
friend class OptionalCacheRecordPolicy;
- CachePolicy RecordPolicy = CachePolicy::Default;
- CachePolicy DefaultValuePolicy = CachePolicy::Default;
- RefPtr<const Private::ICacheRecordPolicyShared> Shared;
+ CachePolicy RecordPolicy = CachePolicy::Default;
+ CachePolicy DefaultValuePolicy = CachePolicy::Default;
+ Ref<const Private::ICacheRecordPolicyShared> Shared;
};
/** A cache record policy builder is used to construct a cache record policy. */
@@ -186,8 +186,8 @@ public:
CacheRecordPolicy Build();
private:
- CachePolicy BasePolicy = CachePolicy::Default;
- RefPtr<Private::ICacheRecordPolicyShared> Shared;
+ CachePolicy BasePolicy = CachePolicy::Default;
+ Ref<Private::ICacheRecordPolicyShared> Shared;
};
/**
diff --git a/src/zenstore/include/zenstore/cidstore.h b/src/zenstore/include/zenstore/cidstore.h
index d54062476..c00e0449f 100644
--- a/src/zenstore/include/zenstore/cidstore.h
+++ b/src/zenstore/include/zenstore/cidstore.h
@@ -58,16 +58,14 @@ struct CidStoreConfiguration
*
*/
-class CidStore final : public ChunkResolver, public StatsProvider
+class CidStore final : public ChunkStore, public StatsProvider
{
public:
CidStore(GcManager& Gc);
~CidStore();
- struct InsertResult
- {
- bool New = false;
- };
+ using InsertResult = ChunkStore::InsertResult;
+
enum class InsertMode
{
kCopyOnly,
@@ -75,17 +73,17 @@ public:
};
void Initialize(const CidStoreConfiguration& Config);
- InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash, InsertMode Mode = InsertMode::kMayBeMovedInPlace);
- std::vector<InsertResult> AddChunks(std::span<IoBuffer> ChunkDatas,
- std::span<IoHash> RawHashes,
- InsertMode Mode = InsertMode::kMayBeMovedInPlace);
- virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override;
+ InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash) override;
+ InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash, InsertMode Mode);
+ std::vector<InsertResult> AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes) override;
+ std::vector<InsertResult> AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes, InsertMode Mode);
+ IoBuffer FindChunkByCid(const IoHash& DecompressedId) override;
bool IterateChunks(std::span<IoHash> DecompressedIds,
const std::function<bool(size_t Index, const IoBuffer& Payload)>& AsyncCallback,
WorkerThreadPool* OptionalWorkerPool,
uint64_t LargeSizeLimit);
- bool ContainsChunk(const IoHash& DecompressedId);
- void FilterChunks(HashKeySet& InOutChunks);
+ bool ContainsChunk(const IoHash& DecompressedId) override;
+ void FilterChunks(HashKeySet& InOutChunks) override;
void Flush();
CidStoreSize TotalSize() const;
CidStoreStats Stats() const;
diff --git a/src/zenstore/include/zenstore/memorycidstore.h b/src/zenstore/include/zenstore/memorycidstore.h
new file mode 100644
index 000000000..0311274d5
--- /dev/null
+++ b/src/zenstore/include/zenstore/memorycidstore.h
@@ -0,0 +1,68 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "cidstore.h"
+#include "zenstore.h"
+
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/thread.h>
+
+#include <deque>
+#include <span>
+#include <thread>
+#include <unordered_map>
+
+namespace zen {
+
+class HashKeySet;
+
+/** Memory-backed chunk store.
+ *
+ * Stores chunks in an in-memory hash map, optionally layered over a
+ * standard CidStore for write-through and read fallback. When a backing
+ * store is provided:
+ *
+ * - AddChunk writes to memory and asynchronously to the backing store.
+ * - FindChunkByCid checks memory first, then falls back to the backing store.
+ * - ContainsChunk and FilterChunks check memory first, then the backing store.
+ *
+ * The memory store does NOT cache read-through results from the backing store.
+ * Only chunks explicitly added via AddChunk/AddChunks are held in memory.
+ */
+
+class MemoryCidStore : public ChunkStore
+{
+public:
+ explicit MemoryCidStore(CidStore* BackingStore = nullptr);
+ ~MemoryCidStore();
+
+ InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash) override;
+ std::vector<InsertResult> AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes) override;
+ IoBuffer FindChunkByCid(const IoHash& DecompressedId) override;
+ bool ContainsChunk(const IoHash& DecompressedId) override;
+ void FilterChunks(HashKeySet& InOutChunks) override;
+
+private:
+ RwLock m_Lock;
+ std::unordered_map<IoHash, IoBuffer, IoHash::Hasher> m_Chunks;
+ CidStore* m_BackingStore = nullptr;
+
+ // Async write-through to backing store
+ struct PendingWrite
+ {
+ IoBuffer Data;
+ IoHash Hash;
+ };
+
+ std::mutex m_FlushLock;
+ std::vector<PendingWrite> m_FlushQueue;
+ Event m_FlushEvent;
+ std::thread m_FlushThread;
+ std::atomic<bool> m_FlushThreadEnabled{false};
+
+ void FlushThreadFunction();
+};
+
+} // namespace zen
diff --git a/src/zenstore/include/zenstore/projectstore.h b/src/zenstore/include/zenstore/projectstore.h
index 100a82907..d05261967 100644
--- a/src/zenstore/include/zenstore/projectstore.h
+++ b/src/zenstore/include/zenstore/projectstore.h
@@ -305,11 +305,11 @@ public:
std::unordered_set<IoHash, IoHash::Hasher> m_PendingPrepOpAttachments;
GcClock::TimePoint m_PendingPrepOpAttachmentsRetainEnd;
- RefPtr<OplogStorage> m_Storage;
- uint64_t m_LogFlushPosition = 0;
- bool m_IsLegacySnapshot = false;
+ Ref<OplogStorage> m_Storage;
+ uint64_t m_LogFlushPosition = 0;
+ bool m_IsLegacySnapshot = false;
- RefPtr<OplogStorage> GetStorage();
+ Ref<OplogStorage> GetStorage();
/** Scan oplog and register each entry, thus updating the in-memory tracking tables
*/
@@ -484,7 +484,7 @@ public:
Project& Project,
Oplog& Oplog,
const std::unordered_set<std::string>& WantedFieldNames);
- static CbObject GetChunkInfo(LoggerRef InLog, Project& Project, Oplog& Oplog, const Oid& ChunkId);
+ static CbObject GetChunkInfo(Project& Project, Oplog& Oplog, const Oid& ChunkId);
struct GetChunkRangeResult
{
enum class EError : uint8_t
@@ -502,8 +502,7 @@ public:
uint64_t RawSize = 0;
ZenContentType ContentType = ZenContentType::kUnknownContentType;
};
- static GetChunkRangeResult GetChunkRange(LoggerRef InLog,
- Project& Project,
+ static GetChunkRangeResult GetChunkRange(Project& Project,
Oplog& Oplog,
const Oid& ChunkId,
uint64_t Offset,
diff --git a/src/zenstore/include/zenstore/zenstore.h b/src/zenstore/include/zenstore/zenstore.h
index bed219b4b..95ae33a4a 100644
--- a/src/zenstore/include/zenstore/zenstore.h
+++ b/src/zenstore/include/zenstore/zenstore.h
@@ -4,19 +4,56 @@
#include <zencore/zencore.h>
+#include <span>
+#include <vector>
+
#define ZENSTORE_API
namespace zen {
+class HashKeySet;
class IoBuffer;
struct IoHash;
class ChunkResolver
{
public:
+ virtual ~ChunkResolver() = default;
virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) = 0;
};
+/** Abstract chunk store interface.
+ *
+ * Extends ChunkResolver with write and query operations. Both CidStore
+ * (disk-backed) and MemoryCidStore (in-memory) implement this interface,
+ * allowing callers to be agnostic about the storage backend.
+ */
+class ChunkStore : public ChunkResolver
+{
+public:
+ struct InsertResult
+ {
+ bool New = false;
+ };
+
+ virtual InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash) = 0;
+ virtual std::vector<InsertResult> AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes) = 0;
+ virtual bool ContainsChunk(const IoHash& DecompressedId) = 0;
+ virtual void FilterChunks(HashKeySet& InOutChunks) = 0;
+};
+
+/** Composite resolver that tries a primary store first, then a fallback. */
+class FallbackChunkResolver : public ChunkResolver
+{
+public:
+ FallbackChunkResolver(ChunkResolver& Primary, ChunkResolver& Fallback);
+ IoBuffer FindChunkByCid(const IoHash& DecompressedId) override;
+
+private:
+ ChunkResolver& m_Primary;
+ ChunkResolver& m_Fallback;
+};
+
ZENSTORE_API void zenstore_forcelinktests();
} // namespace zen
diff --git a/src/zenstore/memorycidstore.cpp b/src/zenstore/memorycidstore.cpp
new file mode 100644
index 000000000..b4832029b
--- /dev/null
+++ b/src/zenstore/memorycidstore.cpp
@@ -0,0 +1,143 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenstore/hashkeyset.h>
+#include <zenstore/memorycidstore.h>
+
+namespace zen {
+
+MemoryCidStore::MemoryCidStore(CidStore* BackingStore) : m_BackingStore(BackingStore)
+{
+ if (m_BackingStore)
+ {
+ m_FlushThreadEnabled = true;
+ m_FlushThread = std::thread(&MemoryCidStore::FlushThreadFunction, this);
+ }
+}
+
+MemoryCidStore::~MemoryCidStore()
+{
+ m_FlushThreadEnabled = false;
+ m_FlushEvent.Set();
+ if (m_FlushThread.joinable())
+ {
+ m_FlushThread.join();
+ }
+}
+
+MemoryCidStore::InsertResult
+MemoryCidStore::AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash)
+{
+ bool IsNew = false;
+
+ m_Lock.WithExclusiveLock([&] {
+ auto [It, Inserted] = m_Chunks.try_emplace(RawHash, ChunkData);
+ IsNew = Inserted;
+ });
+
+ if (m_BackingStore)
+ {
+ std::lock_guard<std::mutex> Lock(m_FlushLock);
+ m_FlushQueue.push_back({.Data = ChunkData, .Hash = RawHash});
+ m_FlushEvent.Set();
+ }
+
+ return {.New = IsNew};
+}
+
+std::vector<MemoryCidStore::InsertResult>
+MemoryCidStore::AddChunks(std::span<IoBuffer> ChunkDatas, std::span<IoHash> RawHashes)
+{
+ std::vector<MemoryCidStore::InsertResult> Results;
+ Results.reserve(ChunkDatas.size());
+
+ for (size_t i = 0; i < ChunkDatas.size(); ++i)
+ {
+ Results.push_back(AddChunk(ChunkDatas[i], RawHashes[i]));
+ }
+
+ return Results;
+}
+
+IoBuffer
+MemoryCidStore::FindChunkByCid(const IoHash& DecompressedId)
+{
+ IoBuffer Result;
+
+ m_Lock.WithSharedLock([&] {
+ auto It = m_Chunks.find(DecompressedId);
+ if (It != m_Chunks.end())
+ {
+ Result = It->second;
+ }
+ });
+
+ if (!Result && m_BackingStore)
+ {
+ Result = m_BackingStore->FindChunkByCid(DecompressedId);
+ }
+
+ return Result;
+}
+
+bool
+MemoryCidStore::ContainsChunk(const IoHash& DecompressedId)
+{
+ bool Found = false;
+
+ m_Lock.WithSharedLock([&] { Found = m_Chunks.find(DecompressedId) != m_Chunks.end(); });
+
+ if (!Found && m_BackingStore)
+ {
+ Found = m_BackingStore->ContainsChunk(DecompressedId);
+ }
+
+ return Found;
+}
+
+void
+MemoryCidStore::FilterChunks(HashKeySet& InOutChunks)
+{
+ // Remove hashes that are present in our memory store
+ m_Lock.WithSharedLock([&] { InOutChunks.RemoveHashesIf([&](const IoHash& Hash) { return m_Chunks.find(Hash) != m_Chunks.end(); }); });
+
+ // Delegate remainder to backing store
+ if (m_BackingStore && !InOutChunks.IsEmpty())
+ {
+ m_BackingStore->FilterChunks(InOutChunks);
+ }
+}
+
+void
+MemoryCidStore::FlushThreadFunction()
+{
+ SetCurrentThreadName("MemCidFlush");
+
+ while (m_FlushThreadEnabled)
+ {
+ m_FlushEvent.Wait();
+
+ std::vector<PendingWrite> Batch;
+ {
+ std::lock_guard<std::mutex> Lock(m_FlushLock);
+ Batch.swap(m_FlushQueue);
+ }
+
+ for (PendingWrite& Write : Batch)
+ {
+ m_BackingStore->AddChunk(Write.Data, Write.Hash);
+ }
+ }
+
+ // Drain remaining writes on shutdown
+ std::vector<PendingWrite> Remaining;
+ {
+ std::lock_guard<std::mutex> Lock(m_FlushLock);
+ Remaining.swap(m_FlushQueue);
+ }
+ for (PendingWrite& Write : Remaining)
+ {
+ m_BackingStore->AddChunk(Write.Data, Write.Hash);
+ }
+}
+
+} // namespace zen
diff --git a/src/zenstore/projectstore.cpp b/src/zenstore/projectstore.cpp
index 13674da4d..8159f9f83 100644
--- a/src/zenstore/projectstore.cpp
+++ b/src/zenstore/projectstore.cpp
@@ -3180,6 +3180,7 @@ ProjectStore::Oplog::AddFileMapping(const RwLock::ExclusiveLockScope&,
}
else
{
+ m_ChunkMap.erase(FileId);
Entry.ServerPath = ServerPath;
}
@@ -3402,11 +3403,11 @@ ProjectStore::Oplog::AppendNewOplogEntry(CbPackage OpPackage)
return EntryId;
}
-RefPtr<ProjectStore::OplogStorage>
+Ref<ProjectStore::OplogStorage>
ProjectStore::Oplog::GetStorage()
{
ZEN_MEMSCOPE(GetProjectstoreTag());
- RefPtr<OplogStorage> Storage;
+ Ref<OplogStorage> Storage;
{
RwLock::SharedLockScope _(m_OplogLock);
Storage = m_Storage;
@@ -3424,7 +3425,7 @@ ProjectStore::Oplog::AppendNewOplogEntry(CbObjectView Core)
using namespace std::literals;
- RefPtr<OplogStorage> Storage = GetStorage();
+ Ref<OplogStorage> Storage = GetStorage();
if (!Storage)
{
return {};
@@ -3456,7 +3457,7 @@ ProjectStore::Oplog::AppendNewOplogEntries(std::span<CbObjectView> Cores)
using namespace std::literals;
- RefPtr<OplogStorage> Storage = GetStorage();
+ Ref<OplogStorage> Storage = GetStorage();
if (!Storage)
{
return std::vector<ProjectStore::LogSequenceNumber>(Cores.size(), LogSequenceNumber{});
@@ -3515,7 +3516,18 @@ ProjectStore::Project::~Project()
// Only write access times if we have not been explicitly deleted
if (!m_OplogStoragePath.empty())
{
- WriteAccessTimes();
+ try
+ {
+ WriteAccessTimes();
+ }
+ catch (const std::exception& Ex)
+ {
+ // RefCounted::Release() is noexcept, so a destructor that propagates an exception
+ // terminates the program. WriteAccessTimes() already catches I/O failures, but lock
+ // acquisition and allocations ahead of the internal try/catch can still throw, so
+ // we defend here as well.
+ ZEN_ERROR("project '{}': ~Project threw exception: '{}'", Identifier, Ex.what());
+ }
}
}
@@ -4738,7 +4750,7 @@ ProjectStore::GetProjectsList()
CbObject
ProjectStore::GetProjectFiles(LoggerRef InLog, Project& Project, Oplog& Oplog, const std::unordered_set<std::string>& WantedFieldNames)
{
- auto Log = [&InLog]() { return InLog; };
+ ZEN_SCOPED_LOG(InLog);
using namespace std::literals;
@@ -4893,7 +4905,7 @@ ProjectStore::GetProjectChunkInfos(LoggerRef InLog, Project& Project, Oplog& Opl
ZEN_MEMSCOPE(GetProjectstoreTag());
ZEN_TRACE_CPU("ProjectStore::GetProjectChunkInfos");
- auto Log = [&InLog]() { return InLog; };
+ ZEN_SCOPED_LOG(InLog);
using namespace std::literals;
@@ -5050,16 +5062,13 @@ ProjectStore::GetProjectChunkInfos(LoggerRef InLog, Project& Project, Oplog& Opl
}
CbObject
-ProjectStore::GetChunkInfo(LoggerRef InLog, Project& Project, Oplog& Oplog, const Oid& ChunkId)
+ProjectStore::GetChunkInfo(Project& Project, Oplog& Oplog, const Oid& ChunkId)
{
ZEN_MEMSCOPE(GetProjectstoreTag());
ZEN_TRACE_CPU("ProjectStore::GetChunkInfo");
using namespace std::literals;
- auto Log = [&InLog]() { return InLog; };
- ZEN_UNUSED(Log);
-
IoBuffer Chunk = Oplog.FindChunk(Project.RootDir, ChunkId, nullptr);
if (!Chunk)
{
@@ -5168,7 +5177,10 @@ ExtractRange(IoBuffer&& Chunk, uint64_t Offset, uint64_t Size, ZenContentType Ac
const bool IsFullRange = (Offset == 0) && ((Size == ~(0ull)) || (Size == ChunkSize));
if (IsFullRange)
{
- Result.Chunk = CompositeBuffer(SharedBuffer(std::move(Chunk)));
+ if (ChunkSize > 0)
+ {
+ Result.Chunk = CompositeBuffer(SharedBuffer(std::move(Chunk)));
+ }
Result.RawSize = 0;
}
else
@@ -5205,8 +5217,7 @@ ExtractRange(IoBuffer&& Chunk, uint64_t Offset, uint64_t Size, ZenContentType Ac
}
ProjectStore::GetChunkRangeResult
-ProjectStore::GetChunkRange(LoggerRef InLog,
- Project& Project,
+ProjectStore::GetChunkRange(Project& Project,
Oplog& Oplog,
const Oid& ChunkId,
uint64_t Offset,
@@ -5218,9 +5229,6 @@ ProjectStore::GetChunkRange(LoggerRef InLog,
ZEN_TRACE_CPU("ProjectStore::GetChunkRange");
- auto Log = [&InLog]() { return InLog; };
- ZEN_UNUSED(Log);
-
uint64_t OldTag = OptionalInOutModificationTag == nullptr ? 0 : *OptionalInOutModificationTag;
IoBuffer Chunk = Oplog.FindChunk(Project.RootDir, ChunkId, OptionalInOutModificationTag);
if (!Chunk)
@@ -5727,7 +5735,7 @@ public:
ZEN_TRACE_CPU("Store::CompactStore");
ZEN_MEMSCOPE(GetProjectstoreTag());
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -5863,7 +5871,7 @@ ProjectStore::RemoveExpiredData(GcCtx& Ctx, GcStats& Stats)
ZEN_TRACE_CPU("Store::RemoveExpiredData");
ZEN_MEMSCOPE(GetProjectstoreTag());
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -6016,7 +6024,7 @@ public:
{
ZEN_TRACE_CPU("Store::UpdateLockedState");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
@@ -6093,7 +6101,7 @@ public:
{
ZEN_TRACE_CPU("Store::GetUnusedReferences");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
size_t InitialCount = IoCids.size();
size_t UsedCount = InitialCount;
@@ -6117,6 +6125,7 @@ public:
}
private:
+ LoggerRef Log() { return m_ProjectStore.Log(); }
ProjectStore& m_ProjectStore;
std::vector<IoHash> m_References;
};
@@ -6161,7 +6170,7 @@ public:
{
ZEN_TRACE_CPU("Store::Oplog::PreCache");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -6279,7 +6288,7 @@ public:
{
ZEN_TRACE_CPU("Store::Oplog::UpdateLockedState");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
Stopwatch Timer;
const auto _ = MakeGuard([&] {
@@ -6387,7 +6396,7 @@ public:
{
ZEN_TRACE_CPU("Store::Oplog::GetUnusedReferences");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
const size_t InitialCount = IoCids.size();
size_t UsedCount = InitialCount;
@@ -6413,6 +6422,7 @@ public:
return UnusedReferences;
}
+ LoggerRef Log() { return m_Project->Log(); }
ProjectStore& m_ProjectStore;
Ref<ProjectStore::Project> m_Project;
std::string m_OplogId;
@@ -6428,7 +6438,7 @@ ProjectStore::CreateReferenceCheckers(GcCtx& Ctx)
{
ZEN_TRACE_CPU("Store::CreateReferenceCheckers");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
size_t ProjectCount = 0;
size_t OplogCount = 0;
@@ -6490,11 +6500,9 @@ ProjectStore::CreateReferenceCheckers(GcCtx& Ctx)
std::vector<RwLock::SharedLockScope>
ProjectStore::LockState(GcCtx& Ctx)
{
+ ZEN_UNUSED(Ctx);
ZEN_TRACE_CPU("Store::LockState");
- auto Log = [&Ctx]() { return Ctx.Logger; };
- ZEN_UNUSED(Log);
-
std::vector<RwLock::SharedLockScope> Locks;
Locks.emplace_back(RwLock::SharedLockScope(m_ProjectsLock));
for (auto& ProjectIt : m_Projects)
@@ -6526,7 +6534,7 @@ public:
{
ZEN_TRACE_CPU("Store::Validate");
- auto Log = [&Ctx]() { return Ctx.Logger; };
+ ZEN_SCOPED_LOG(Ctx.Logger);
ProjectStore::Oplog::ValidationResult Result;
@@ -6625,9 +6633,6 @@ ProjectStore::CreateReferenceValidators(GcCtx& Ctx)
return {};
}
- auto Log = [&Ctx]() { return Ctx.Logger; };
- ZEN_UNUSED(Log);
-
DiscoverProjects();
std::vector<std::pair<std::string, std::string>> Oplogs;
@@ -8242,8 +8247,7 @@ TEST_CASE("project.store.partial.read")
{
uint64_t ModificationTag = 0;
- auto Result = ProjectStore.GetChunkRange(Log(),
- *Project1,
+ auto Result = ProjectStore.GetChunkRange(*Project1,
*Oplog1,
Attachments[OpIds[1]][0].first,
0,
@@ -8258,8 +8262,7 @@ TEST_CASE("project.store.partial.read")
CompressedBuffer Attachment = CompressedBuffer::FromCompressed(Result.Chunk, RawHash, RawSize);
CHECK(RawSize == Attachments[OpIds[1]][0].second.DecodeRawSize());
- auto Result2 = ProjectStore.GetChunkRange(Log(),
- *Project1,
+ auto Result2 = ProjectStore.GetChunkRange(*Project1,
*Oplog1,
Attachments[OpIds[1]][0].first,
0,
@@ -8272,8 +8275,7 @@ TEST_CASE("project.store.partial.read")
{
uint64_t FullChunkModificationTag = 0;
{
- auto Result = ProjectStore.GetChunkRange(Log(),
- *Project1,
+ auto Result = ProjectStore.GetChunkRange(*Project1,
*Oplog1,
Attachments[OpIds[2]][1].first,
0,
@@ -8286,8 +8288,7 @@ TEST_CASE("project.store.partial.read")
Attachments[OpIds[2]][1].second.DecodeRawSize());
}
{
- auto Result = ProjectStore.GetChunkRange(Log(),
- *Project1,
+ auto Result = ProjectStore.GetChunkRange(*Project1,
*Oplog1,
Attachments[OpIds[2]][1].first,
0,
@@ -8300,8 +8301,7 @@ TEST_CASE("project.store.partial.read")
{
uint64_t PartialChunkModificationTag = 0;
{
- auto Result = ProjectStore.GetChunkRange(Log(),
- *Project1,
+ auto Result = ProjectStore.GetChunkRange(*Project1,
*Oplog1,
Attachments[OpIds[2]][1].first,
5,
@@ -8324,8 +8324,7 @@ TEST_CASE("project.store.partial.read")
}
{
- auto Result = ProjectStore.GetChunkRange(Log(),
- *Project1,
+ auto Result = ProjectStore.GetChunkRange(*Project1,
*Oplog1,
Attachments[OpIds[2]][1].first,
0,
diff --git a/src/zenstore/workspaces.cpp b/src/zenstore/workspaces.cpp
index ad21bbc68..cfdcd294c 100644
--- a/src/zenstore/workspaces.cpp
+++ b/src/zenstore/workspaces.cpp
@@ -331,9 +331,6 @@ ScanFolder(LoggerRef InLog, const std::filesystem::path& Path, WorkerThreadPool&
{
ZEN_TRACE_CPU("workspaces::ScanFolderImpl");
- auto Log = [&InLog]() { return InLog; };
- ZEN_UNUSED(Log);
-
FolderScanner Data(InLog, WorkerPool, Path);
Data.Traverse();
return std::make_unique<FolderStructure>(std::move(Data.FoundFiles), std::move(Data.FoundFileIds));
@@ -811,7 +808,7 @@ Workspaces::GetShareAlias(std::string_view Alias) const
std::vector<Workspaces::WorkspaceConfiguration>
Workspaces::ReadConfig(const LoggerRef& InLog, const std::filesystem::path& WorkspaceStatePath, std::string& OutError)
{
- auto Log = [&InLog]() { return InLog; };
+ ZEN_SCOPED_LOG(InLog);
using namespace std::literals;
@@ -835,7 +832,7 @@ Workspaces::WriteConfig(const LoggerRef& InLog,
const std::filesystem::path& WorkspaceStatePath,
const std::vector<WorkspaceConfiguration>& WorkspaceConfigurations)
{
- auto Log = [&InLog]() { return InLog; };
+ ZEN_SCOPED_LOG(InLog);
using namespace std::literals;
@@ -850,7 +847,7 @@ Workspaces::WriteConfig(const LoggerRef& InLog,
std::vector<Workspaces::WorkspaceShareConfiguration>
Workspaces::ReadWorkspaceConfig(const LoggerRef& InLog, const std::filesystem::path& WorkspaceRoot, std::string& OutError)
{
- auto Log = [&InLog]() { return InLog; };
+ ZEN_SCOPED_LOG(InLog);
using namespace std::literals;
@@ -874,7 +871,7 @@ Workspaces::WriteWorkspaceConfig(const LoggerRef& InLog,
const std::filesystem::path& WorkspaceRoot,
const std::vector<WorkspaceShareConfiguration>& WorkspaceShareConfigurations)
{
- auto Log = [&InLog]() { return InLog; };
+ ZEN_SCOPED_LOG(InLog);
using namespace std::literals;
@@ -1049,9 +1046,6 @@ Workspaces::RemoveWorkspaceShare(const LoggerRef& Log, const std::filesystem::pa
Workspaces::WorkspaceConfiguration
Workspaces::FindWorkspace(const LoggerRef& InLog, const std::filesystem::path& WorkspaceStatePath, const Oid& WorkspaceId)
{
- auto Log = [&InLog]() { return InLog; };
- ZEN_UNUSED(Log);
-
std::string Error;
std::vector<WorkspaceConfiguration> Workspaces = ReadConfig(InLog, WorkspaceStatePath, Error);
if (!Error.empty())
@@ -1075,9 +1069,6 @@ Workspaces::FindWorkspace(const LoggerRef& InLog,
const std::filesystem::path& WorkspaceStatePath,
const std::filesystem::path& WorkspaceRoot)
{
- auto Log = [&InLog]() { return InLog; };
- ZEN_UNUSED(Log);
-
std::string Error;
std::vector<WorkspaceConfiguration> Workspaces = ReadConfig(InLog, WorkspaceStatePath, Error);
if (!Error.empty())
@@ -1102,7 +1093,7 @@ Workspaces::FindWorkspaceShare(const LoggerRef& InLog,
std::string_view ShareAlias,
WorkspaceConfiguration& OutWorkspace)
{
- auto Log = [&InLog]() { return InLog; };
+ ZEN_SCOPED_LOG(InLog);
std::string Error;
std::vector<WorkspaceConfiguration> Workspaces = ReadConfig(InLog, WorkspaceStatePath, Error);
@@ -1151,7 +1142,7 @@ Workspaces::FindWorkspaceShare(const LoggerRef& InLog,
Workspaces::WorkspaceShareConfiguration
Workspaces::FindWorkspaceShare(const LoggerRef& InLog, const std::filesystem::path& WorkspaceRoot, const Oid& WorkspaceShareId)
{
- auto Log = [&InLog]() { return InLog; };
+ ZEN_SCOPED_LOG(InLog);
std::string Error;
std::vector<WorkspaceShareConfiguration> Shares = ReadWorkspaceConfig(InLog, WorkspaceRoot, Error);
if (!Error.empty())
@@ -1174,7 +1165,7 @@ Workspaces::FindWorkspaceShare(const LoggerRef& InLog, const std::filesystem::pa
Workspaces::WorkspaceShareConfiguration
Workspaces::FindWorkspaceShare(const LoggerRef& InLog, const std::filesystem::path& WorkspaceRoot, const std::filesystem::path& SharePath)
{
- auto Log = [&InLog]() { return InLog; };
+ ZEN_SCOPED_LOG(InLog);
std::string Error;
std::vector<WorkspaceShareConfiguration> Shares = ReadWorkspaceConfig(InLog, WorkspaceRoot, Error);
if (!Error.empty())
diff --git a/src/zenstore/zenstore.cpp b/src/zenstore/zenstore.cpp
index c563cc202..bf0c71211 100644
--- a/src/zenstore/zenstore.cpp
+++ b/src/zenstore/zenstore.cpp
@@ -2,6 +2,26 @@
#include "zenstore/zenstore.h"
+#include <zencore/iobuffer.h>
+
+namespace zen {
+
+FallbackChunkResolver::FallbackChunkResolver(ChunkResolver& Primary, ChunkResolver& Fallback) : m_Primary(Primary), m_Fallback(Fallback)
+{
+}
+
+IoBuffer
+FallbackChunkResolver::FindChunkByCid(const IoHash& DecompressedId)
+{
+ if (IoBuffer Result = m_Primary.FindChunkByCid(DecompressedId))
+ {
+ return Result;
+ }
+ return m_Fallback.FindChunkByCid(DecompressedId);
+}
+
+} // namespace zen
+
#if ZEN_WITH_TESTS
# include <zenstore/blockstore.h>
diff --git a/src/zentelemetry/include/zentelemetry/stats.h b/src/zentelemetry/include/zentelemetry/stats.h
index 260b0fcfb..ddec8e883 100644
--- a/src/zentelemetry/include/zentelemetry/stats.h
+++ b/src/zentelemetry/include/zentelemetry/stats.h
@@ -40,7 +40,7 @@ private:
* Suitable for tracking quantities that go up and down over time, such as
* requests in flight or active jobs. All operations are lock-free via atomics.
*
- * Unlike a Meter, a Counter does not track rates — it only records a running total.
+ * Unlike a Meter, a Counter does not track rates - it only records a running total.
*/
class Counter
{
@@ -63,7 +63,7 @@ private:
* rate = rate + alpha * (instantRate - rate)
*
* where instantRate = Count / Interval. The alpha value controls how quickly
- * the average responds to changes — higher alpha means more weight on recent
+ * the average responds to changes - higher alpha means more weight on recent
* samples. Typical alphas are derived from a decay half-life (e.g. 1, 5, 15
* minutes) and a fixed tick interval.
*
@@ -205,7 +205,7 @@ private:
*
* Records exact min, max, count and mean across all values ever seen, plus a
* reservoir sample (via UniformSample) used to compute percentiles. Percentiles
- * are therefore probabilistic — they reflect the distribution of a representative
+ * are therefore probabilistic - they reflect the distribution of a representative
* sample rather than the full history.
*
* All operations are thread-safe via lock-free atomics.
@@ -343,7 +343,7 @@ struct RequestStatsSnapshot
*
* Maintains two independent histogram+meter pairs: one for request duration
* (in hi-freq timer ticks) and one for transferred bytes. Both dimensions
- * share the same request count — a single Update() call advances both.
+ * share the same request count - a single Update() call advances both.
*
* Duration accessors return values in hi-freq timer ticks. Multiply by
* GetHifreqTimerToSeconds() to convert to seconds.
@@ -383,7 +383,7 @@ public:
* or destruction.
*
* The byte count can be supplied at construction or updated at any point via
- * SetBytes() before the scope ends — useful when the response size is not
+ * SetBytes() before the scope ends - useful when the response size is not
* known until the operation completes.
*
* Call Cancel() to discard the measurement entirely.
diff --git a/src/zentelemetry/otelmetricsprotozero.h b/src/zentelemetry/otelmetricsprotozero.h
index 12bae0d75..f02eca7ab 100644
--- a/src/zentelemetry/otelmetricsprotozero.h
+++ b/src/zentelemetry/otelmetricsprotozero.h
@@ -49,22 +49,22 @@ option go_package = "go.opentelemetry.io/proto/otlp/metrics/v1";
// data but do not implement the OTLP protocol.
//
// MetricsData
-// └─── ResourceMetrics
-// ├── Resource
-// ├── SchemaURL
-// └── ScopeMetrics
-// ├── Scope
-// ├── SchemaURL
-// └── Metric
-// ├── Name
-// ├── Description
-// ├── Unit
-// └── data
-// ├── Gauge
-// ├── Sum
-// ├── Histogram
-// ├── ExponentialHistogram
-// └── Summary
+// +--- ResourceMetrics
+// |-- Resource
+// |-- SchemaURL
+// +-- ScopeMetrics
+// |-- Scope
+// |-- SchemaURL
+// +-- Metric
+// |-- Name
+// |-- Description
+// |-- Unit
+// +-- data
+// |-- Gauge
+// |-- Sum
+// |-- Histogram
+// |-- ExponentialHistogram
+// +-- Summary
//
// The main difference between this message and collector protocol is that
// in this message there will not be any "control" or "metadata" specific to
diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp
index 54e54edde..73cb7ff2d 100644
--- a/src/zentest-appstub/zentest-appstub.cpp
+++ b/src/zentest-appstub/zentest-appstub.cpp
@@ -33,6 +33,13 @@ using namespace zen;
// Some basic functions to implement some test "compute" functions
+struct ForcedExitException : std::exception
+{
+ int Code;
+ explicit ForcedExitException(int InCode) : Code(InCode) {}
+ const char* what() const noexcept override { return "forced exit"; }
+};
+
std::string
Rot13Function(std::string_view InputString)
{
@@ -111,6 +118,16 @@ DescribeFunctions()
<< "Sleep"sv;
Versions << "Version"sv << Guid::FromString("88888888-8888-8888-8888-888888888888"sv);
Versions.EndObject();
+ Versions.BeginObject();
+ Versions << "Name"sv
+ << "Fail"sv;
+ Versions << "Version"sv << Guid::FromString("fa11fa11-fa11-fa11-fa11-fa11fa11fa11"sv);
+ Versions.EndObject();
+ Versions.BeginObject();
+ Versions << "Name"sv
+ << "Crash"sv;
+ Versions << "Version"sv << Guid::FromString("c4a50000-c4a5-c4a5-c4a5-c4a5c4a5c4a5"sv);
+ Versions.EndObject();
Versions.EndArray();
return Versions.Save();
@@ -201,6 +218,38 @@ ExecuteFunction(CbObject Action, ContentResolver ChunkResolver)
zen::Sleep(static_cast<int>(SleepTimeMs));
return Apply(IdentityFunction);
}
+ else if (Function == "Fail"sv)
+ {
+ int FailExitCode = static_cast<int>(Action["Constants"sv].AsObjectView()["ExitCode"sv].AsUInt64());
+ if (FailExitCode == 0)
+ {
+ FailExitCode = 1;
+ }
+ throw ForcedExitException(FailExitCode);
+ }
+ else if (Function == "Crash"sv)
+ {
+ // Crash modes:
+ // "abort" - calls std::abort() (SIGABRT / process termination)
+ // "nullptr" - dereferences a null pointer (SIGSEGV / access violation)
+ std::string_view Mode = Action["Constants"sv].AsObjectView()["Mode"sv].AsString();
+
+ printf("[zentest] crashing with mode: %.*s\n", int(Mode.size()), Mode.data());
+ fflush(stdout);
+
+ if (Mode == "nullptr"sv)
+ {
+ volatile int* Ptr = nullptr;
+ *Ptr = 42;
+ }
+
+ // Default crash mode (also reached after nullptr write on platforms
+ // that don't immediately fault on null dereference)
+#if defined(_MSC_VER)
+ _set_abort_behavior(0, _WRITE_ABORT_MSG | _CALL_REPORTFAULT);
+#endif
+ std::abort();
+ }
else
{
return {};
@@ -421,6 +470,12 @@ main(int argc, char* argv[])
}
}
}
+ catch (ForcedExitException& Ex)
+ {
+ printf("[zentest] forced exit with code: %d\n", Ex.Code);
+
+ ExitCode = Ex.Code;
+ }
catch (std::exception& Ex)
{
printf("[zentest] exception caught in main: '%s'\n", Ex.what());
diff --git a/src/zenutil/cloud/imdscredentials.cpp b/src/zenutil/cloud/imdscredentials.cpp
index dde1dc019..b025eb6da 100644
--- a/src/zenutil/cloud/imdscredentials.cpp
+++ b/src/zenutil/cloud/imdscredentials.cpp
@@ -115,7 +115,7 @@ ImdsCredentialProvider::FetchToken()
HttpClient::KeyValueMap Headers;
Headers->emplace("X-aws-ec2-metadata-token-ttl-seconds", "21600");
- HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", Headers);
+ HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", IoBuffer{}, Headers);
if (!Response.IsSuccess())
{
ZEN_WARN("IMDS token request failed: {}", Response.ErrorMessage("PUT /latest/api/token"));
@@ -213,7 +213,7 @@ ImdsCredentialProvider::FetchCredentials()
}
else
{
- // Expiration is in the past or unparseable — force refresh next time
+ // Expiration is in the past or unparseable - force refresh next time
NewExpiresAt = std::chrono::steady_clock::now();
}
@@ -369,7 +369,7 @@ TEST_CASE("imdscredentials.fetch_from_mock")
TEST_CASE("imdscredentials.unreachable_endpoint")
{
- // Point at a non-existent server — should return empty credentials, not crash
+ // Point at a non-existent server - should return empty credentials, not crash
ImdsCredentialProviderOptions Opts;
Opts.Endpoint = "http://127.0.0.1:1"; // unlikely to have anything listening
Opts.ConnectTimeout = std::chrono::milliseconds(100);
diff --git a/src/zenutil/cloud/minioprocess.cpp b/src/zenutil/cloud/minioprocess.cpp
index 457453bd8..2db0010dc 100644
--- a/src/zenutil/cloud/minioprocess.cpp
+++ b/src/zenutil/cloud/minioprocess.cpp
@@ -45,7 +45,7 @@ struct MinioProcess::Impl
}
CreateProcOptions Options;
- Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup;
+ Options.Flags |= CreateProcOptions::Flag_NewProcessGroup;
Options.Environment.emplace_back("MINIO_ROOT_USER", m_Options.RootUser);
Options.Environment.emplace_back("MINIO_ROOT_PASSWORD", m_Options.RootPassword);
@@ -102,7 +102,7 @@ struct MinioProcess::Impl
{
if (m_DataDir.empty())
{
- ZEN_WARN("MinIO: Cannot create bucket before data directory is initialized — call SpawnMinioServer() first");
+ ZEN_WARN("MinIO: Cannot create bucket before data directory is initialized - call SpawnMinioServer() first");
return;
}
diff --git a/src/zenutil/cloud/mockimds.cpp b/src/zenutil/cloud/mockimds.cpp
index 6919fab4d..88b348ed6 100644
--- a/src/zenutil/cloud/mockimds.cpp
+++ b/src/zenutil/cloud/mockimds.cpp
@@ -93,7 +93,7 @@ MockImdsService::HandleAwsRequest(HttpServerRequest& Request)
return;
}
- // Autoscaling lifecycle state — 404 when not in an ASG
+ // Autoscaling lifecycle state - 404 when not in an ASG
if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state")
{
if (Aws.AutoscalingState.empty())
@@ -105,7 +105,7 @@ MockImdsService::HandleAwsRequest(HttpServerRequest& Request)
return;
}
- // Spot interruption notice — 404 when no interruption pending
+ // Spot interruption notice - 404 when no interruption pending
if (Uri == "latest/meta-data/spot/instance-action")
{
if (Aws.SpotAction.empty())
@@ -117,7 +117,7 @@ MockImdsService::HandleAwsRequest(HttpServerRequest& Request)
return;
}
- // IAM role discovery — returns the role name
+ // IAM role discovery - returns the role name
if (Uri == "latest/meta-data/iam/security-credentials/")
{
if (Aws.IamRoleName.empty())
diff --git a/src/zenutil/cloud/s3client.cpp b/src/zenutil/cloud/s3client.cpp
index 26d1023f4..83238f5cc 100644
--- a/src/zenutil/cloud/s3client.cpp
+++ b/src/zenutil/cloud/s3client.cpp
@@ -137,6 +137,8 @@ namespace {
} // namespace
+std::string_view S3GetObjectResult::NotFoundErrorText = "Not found";
+
S3Client::S3Client(const S3ClientOptions& Options)
: m_Log(logging::Get("s3"))
, m_BucketName(Options.BucketName)
@@ -145,13 +147,8 @@ S3Client::S3Client(const S3ClientOptions& Options)
, m_PathStyle(Options.PathStyle)
, m_Credentials(Options.Credentials)
, m_CredentialProvider(Options.CredentialProvider)
-, m_HttpClient(BuildEndpoint(),
- HttpClientSettings{
- .LogCategory = "s3",
- .ConnectTimeout = Options.ConnectTimeout,
- .Timeout = Options.Timeout,
- .RetryCount = Options.RetryCount,
- })
+, m_HttpClient(BuildEndpoint(), Options.HttpSettings)
+, m_Verbose(Options.HttpSettings.Verbose)
{
m_Host = BuildHostHeader();
ZEN_INFO("S3 client configured for bucket '{}' in region '{}' (endpoint: {}, {})",
@@ -342,26 +339,37 @@ S3Client::PutObject(std::string_view Key, IoBuffer Content)
return S3Result{std::move(Err)};
}
- ZEN_DEBUG("S3 PUT '{}' succeeded ({} bytes)", Key, Content.GetSize());
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 PUT '{}' succeeded ({} bytes)", Key, Content.GetSize());
+ }
return {};
}
S3GetObjectResult
-S3Client::GetObject(std::string_view Key)
+S3Client::GetObject(std::string_view Key, const std::filesystem::path& TempFilePath)
{
std::string Path = KeyToPath(Key);
HttpClient::KeyValueMap Headers = SignRequest("GET", Path, "", EmptyPayloadHash);
- HttpClient::Response Response = m_HttpClient.Get(Path, Headers);
+ HttpClient::Response Response = m_HttpClient.Download(Path, TempFilePath, Headers);
if (!Response.IsSuccess())
{
+ if (Response.StatusCode == HttpResponseCode::NotFound)
+ {
+ return S3GetObjectResult{S3Result{.Error = std::string(S3GetObjectResult::NotFoundErrorText)}, {}};
+ }
+
std::string Err = Response.ErrorMessage("S3 GET failed");
ZEN_WARN("S3 GET '{}' failed: {}", Key, Err);
return S3GetObjectResult{S3Result{std::move(Err)}, {}};
}
- ZEN_DEBUG("S3 GET '{}' succeeded ({} bytes)", Key, Response.ResponsePayload.GetSize());
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 GET '{}' succeeded ({} bytes)", Key, Response.ResponsePayload.GetSize());
+ }
return S3GetObjectResult{{}, std::move(Response.ResponsePayload)};
}
@@ -377,6 +385,11 @@ S3Client::GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t Ran
HttpClient::Response Response = m_HttpClient.Get(Path, Headers);
if (!Response.IsSuccess())
{
+ if (Response.StatusCode == HttpResponseCode::NotFound)
+ {
+ return S3GetObjectResult{S3Result{.Error = std::string(S3GetObjectResult::NotFoundErrorText)}, {}};
+ }
+
std::string Err = Response.ErrorMessage("S3 GET range failed");
ZEN_WARN("S3 GET range '{}' [{}-{}] failed: {}", Key, RangeStart, RangeStart + RangeSize - 1, Err);
return S3GetObjectResult{S3Result{std::move(Err)}, {}};
@@ -397,11 +410,14 @@ S3Client::GetObjectRange(std::string_view Key, uint64_t RangeStart, uint64_t Ran
return S3GetObjectResult{S3Result{std::move(Err)}, {}};
}
- ZEN_DEBUG("S3 GET range '{}' [{}-{}] succeeded ({} bytes)",
- Key,
- RangeStart,
- RangeStart + RangeSize - 1,
- Response.ResponsePayload.GetSize());
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 GET range '{}' [{}-{}] succeeded ({} bytes)",
+ Key,
+ RangeStart,
+ RangeStart + RangeSize - 1,
+ Response.ResponsePayload.GetSize());
+ }
return S3GetObjectResult{{}, std::move(Response.ResponsePayload)};
}
@@ -420,7 +436,10 @@ S3Client::DeleteObject(std::string_view Key)
return S3Result{std::move(Err)};
}
- ZEN_DEBUG("S3 DELETE '{}' succeeded", Key);
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 DELETE '{}' succeeded", Key);
+ }
return {};
}
@@ -462,7 +481,10 @@ S3Client::HeadObject(std::string_view Key)
Info.LastModified = *V;
}
- ZEN_DEBUG("S3 HEAD '{}' succeeded (size={})", Key, Info.Size);
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 HEAD '{}' succeeded (size={})", Key, Info.Size);
+ }
return S3HeadObjectResult{{}, std::move(Info), HeadObjectResult::Found};
}
@@ -559,10 +581,16 @@ S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys)
}
ContinuationToken = std::string(NextToken);
- ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' fetching next page ({} objects so far)", Prefix, Result.Objects.size());
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 ListObjectsV2 prefix='{}' fetching next page ({} objects so far)", Prefix, Result.Objects.size());
+ }
}
- ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' returned {} objects", Prefix, Result.Objects.size());
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 ListObjectsV2 prefix='{}' returned {} objects", Prefix, Result.Objects.size());
+ }
return Result;
}
@@ -601,7 +629,10 @@ S3Client::CreateMultipartUpload(std::string_view Key)
return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}};
}
- ZEN_DEBUG("S3 CreateMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId);
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 CreateMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId);
+ }
return S3CreateMultipartUploadResult{{}, std::string(UploadId)};
}
@@ -636,7 +667,10 @@ S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t P
return S3UploadPartResult{S3Result{std::move(Err)}, {}};
}
- ZEN_DEBUG("S3 UploadPart '{}' part {} succeeded ({} bytes, etag={})", Key, PartNumber, Content.GetSize(), *ETag);
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 UploadPart '{}' part {} succeeded ({} bytes, etag={})", Key, PartNumber, Content.GetSize(), *ETag);
+ }
return S3UploadPartResult{{}, *ETag};
}
@@ -685,7 +719,10 @@ S3Client::CompleteMultipartUpload(std::string_view Key,
return S3Result{std::move(Err)};
}
- ZEN_DEBUG("S3 CompleteMultipartUpload '{}' succeeded ({} parts)", Key, PartETags.size());
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 CompleteMultipartUpload '{}' succeeded ({} parts)", Key, PartETags.size());
+ }
return {};
}
@@ -706,7 +743,10 @@ S3Client::AbortMultipartUpload(std::string_view Key, std::string_view UploadId)
return S3Result{std::move(Err)};
}
- ZEN_DEBUG("S3 AbortMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId);
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 AbortMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId);
+ }
return {};
}
@@ -749,7 +789,10 @@ S3Client::PutObjectMultipart(std::string_view Key,
return PutObject(Key, TotalSize > 0 ? FetchRange(0, TotalSize) : IoBuffer{});
}
- ZEN_INFO("S3 multipart upload '{}': {} bytes in ~{} parts", Key, TotalSize, (TotalSize + PartSize - 1) / PartSize);
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 multipart upload '{}': {} bytes in ~{} parts", Key, TotalSize, (TotalSize + PartSize - 1) / PartSize);
+ }
S3CreateMultipartUploadResult InitResult = CreateMultipartUpload(Key);
if (!InitResult)
@@ -797,7 +840,10 @@ S3Client::PutObjectMultipart(std::string_view Key,
throw;
}
- ZEN_INFO("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), TotalSize);
+ if (m_Verbose)
+ {
+ ZEN_INFO("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), TotalSize);
+ }
return {};
}
@@ -885,7 +931,10 @@ TEST_CASE("s3client.minio_integration")
{
using namespace std::literals;
- // Spawn a local MinIO server
+ // Spawn a single MinIO server for the entire test case. Previously each SUBCASE re-entered
+ // the TEST_CASE from the top, spawning and killing MinIO per subcase - slow and flaky on
+ // macOS CI. Sequential sections avoid the re-entry while still sharing one MinIO instance
+ // that is torn down via RAII at scope exit.
MinioProcessOptions MinioOpts;
MinioOpts.Port = 19000;
MinioOpts.RootUser = "testuser";
@@ -893,11 +942,8 @@ TEST_CASE("s3client.minio_integration")
MinioProcess Minio(MinioOpts);
Minio.SpawnMinioServer();
-
- // Pre-create the test bucket (creates a subdirectory in MinIO's data dir)
Minio.CreateBucket("integration-test");
- // Configure S3Client for the test bucket
S3ClientOptions Opts;
Opts.BucketName = "integration-test";
Opts.Region = "us-east-1";
@@ -908,7 +954,7 @@ TEST_CASE("s3client.minio_integration")
S3Client Client(Opts);
- SUBCASE("put_get_delete")
+ // -- put_get_delete -------------------------------------------------------
{
// PUT
std::string_view TestData = "hello, minio integration test!"sv;
@@ -937,14 +983,14 @@ TEST_CASE("s3client.minio_integration")
CHECK(HeadRes2.Status == HeadObjectResult::NotFound);
}
- SUBCASE("head_not_found")
+ // -- head_not_found -------------------------------------------------------
{
S3HeadObjectResult Res = Client.HeadObject("nonexistent/key.dat");
CHECK(Res.IsSuccess());
CHECK(Res.Status == HeadObjectResult::NotFound);
}
- SUBCASE("list_objects")
+ // -- list_objects ---------------------------------------------------------
{
// Upload several objects with a common prefix
for (int i = 0; i < 3; ++i)
@@ -979,7 +1025,7 @@ TEST_CASE("s3client.minio_integration")
}
}
- SUBCASE("multipart_upload")
+ // -- multipart_upload -----------------------------------------------------
{
// Create a payload large enough to exercise multipart (use minimum part size)
constexpr uint64_t PartSize = 5 * 1024 * 1024; // 5 MB minimum
@@ -1006,7 +1052,7 @@ TEST_CASE("s3client.minio_integration")
Client.DeleteObject("multipart/large.bin");
}
- SUBCASE("presigned_urls")
+ // -- presigned_urls -------------------------------------------------------
{
// Upload an object
std::string_view TestData = "presigned-url-test-data"sv;
@@ -1032,8 +1078,6 @@ TEST_CASE("s3client.minio_integration")
// Cleanup
Client.DeleteObject("presigned/test.txt");
}
-
- Minio.StopMinioServer();
}
TEST_SUITE_END();
diff --git a/src/zenutil/consoletui.cpp b/src/zenutil/consoletui.cpp
index 124132aed..10e8abb31 100644
--- a/src/zenutil/consoletui.cpp
+++ b/src/zenutil/consoletui.cpp
@@ -311,7 +311,7 @@ TuiPickOne(std::string_view Title, std::span<const std::string> Items)
printf("\033[1;7m"); // bold + reverse video
}
- // \xe2\x96\xb6 = U+25B6 BLACK RIGHT-POINTING TRIANGLE (▶)
+ // \xe2\x96\xb6 = U+25B6 BLACK RIGHT-POINTING TRIANGLE (>)
const char* Indicator = IsSelected ? " \xe2\x96\xb6 " : " ";
printf("%s%s", Indicator, Items[i].c_str());
@@ -328,7 +328,7 @@ TuiPickOne(std::string_view Title, std::span<const std::string> Items)
printf("\r\033[K\n");
// Hint footer
- // \xe2\x86\x91 = U+2191 ↑ \xe2\x86\x93 = U+2193 ↓
+ // \xe2\x86\x91 = U+2191 ^ \xe2\x86\x93 = U+2193 v
printf(
"\r\033[K \033[2m\xe2\x86\x91/\xe2\x86\x93\033[0m navigate "
"\033[2mEnter\033[0m confirm "
diff --git a/src/zenutil/consul/consul.cpp b/src/zenutil/consul/consul.cpp
index c9144e589..c372b131d 100644
--- a/src/zenutil/consul/consul.cpp
+++ b/src/zenutil/consul/consul.cpp
@@ -9,9 +9,13 @@
#include <zencore/logging.h>
#include <zencore/process.h>
#include <zencore/string.h>
+#include <zencore/testing.h>
+#include <zencore/testutils.h>
#include <zencore/thread.h>
#include <zencore/timer.h>
+#include <zenhttp/httpserver.h>
+
#include <fmt/format.h>
namespace zen::consul {
@@ -31,7 +35,7 @@ struct ConsulProcess::Impl
}
CreateProcOptions Options;
- Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup;
+ Options.Flags |= CreateProcOptions::Flag_NewProcessGroup;
const std::filesystem::path ConsulExe = GetRunningExecutablePath().parent_path() / ("consul" ZEN_EXE_SUFFIX_LITERAL);
CreateProcResult Result = CreateProc(ConsulExe, "consul" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options);
@@ -107,7 +111,7 @@ ConsulProcess::StopConsulAgent()
//////////////////////////////////////////////////////////////////////////
-ConsulClient::ConsulClient(std::string_view BaseUri, std::string_view Token) : m_Token(Token), m_HttpClient(BaseUri)
+ConsulClient::ConsulClient(const Configuration& Config) : m_Config(Config), m_HttpClient(m_Config.BaseUri)
{
}
@@ -193,12 +197,18 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info)
// when no interval is configured (e.g. during Provisioning).
Writer.BeginObject("Check"sv);
{
- Writer.AddString("HTTP"sv, fmt::format("http://{}:{}/{}", Info.Address, Info.Port, Info.HealthEndpoint));
+ Writer.AddString(
+ "HTTP"sv,
+ fmt::format("http://{}:{}/{}", Info.Address.empty() ? "localhost" : Info.Address, Info.Port, Info.HealthEndpoint));
Writer.AddString("Interval"sv, fmt::format("{}s", Info.HealthIntervalSeconds));
if (Info.DeregisterAfterSeconds != 0)
{
Writer.AddString("DeregisterCriticalServiceAfter"sv, fmt::format("{}s", Info.DeregisterAfterSeconds));
}
+ if (!Info.InitialStatus.empty())
+ {
+ Writer.AddString("Status"sv, Info.InitialStatus);
+ }
}
Writer.EndObject(); // Check
}
@@ -223,27 +233,112 @@ ConsulClient::RegisterService(const ServiceRegistrationInfo& Info)
bool
ConsulClient::DeregisterService(std::string_view ServiceId)
{
+ using namespace std::literals;
+
HttpClient::KeyValueMap AdditionalHeaders;
ApplyCommonHeaders(AdditionalHeaders);
AdditionalHeaders.Entries.emplace(HttpClient::Accept(HttpContentType::kJSON));
- HttpClient::Response Result = m_HttpClient.Put(fmt::format("v1/agent/service/deregister/{}", ServiceId), AdditionalHeaders);
+ HttpClient::Response Result = m_HttpClient.Put(fmt::format("v1/agent/service/deregister/{}", ServiceId), IoBuffer{}, AdditionalHeaders);
+ if (Result)
+ {
+ return true;
+ }
+
+ // Agent deregister failed - fall back to catalog deregister.
+ // This handles cases where the service was registered via a different Consul agent
+ // (e.g. load-balanced endpoint routing to different agents).
+ std::string NodeName = GetNodeName();
+ if (!NodeName.empty())
+ {
+ CbObjectWriter Writer;
+ Writer.AddString("Node"sv, NodeName);
+ Writer.AddString("ServiceID"sv, ServiceId);
+ ExtendableStringBuilder<256> SB;
+ CompactBinaryToJson(Writer.Save(), SB);
+
+ IoBuffer PayloadBuffer(IoBuffer::Wrap, SB.Data(), SB.Size());
+ PayloadBuffer.SetContentType(HttpContentType::kJSON);
+
+ HttpClient::Response CatalogResult = m_HttpClient.Put("v1/catalog/deregister", PayloadBuffer, AdditionalHeaders);
+ if (CatalogResult)
+ {
+ ZEN_INFO("ConsulClient::DeregisterService() deregistered service '{}' via catalog fallback (agent error: {})",
+ ServiceId,
+ Result.ErrorMessage(""));
+ return true;
+ }
+
+ ZEN_WARN("ConsulClient::DeregisterService() failed to deregister service '{}' (agent: {}, catalog: {})",
+ ServiceId,
+ Result.ErrorMessage(""),
+ CatalogResult.ErrorMessage(""));
+ }
+ else
+ {
+ ZEN_WARN(
+ "ConsulClient::DeregisterService() failed to deregister service '{}' (agent: {}, could not determine node name for catalog "
+ "fallback)",
+ ServiceId,
+ Result.ErrorMessage(""));
+ }
+
+ return false;
+}
+
+std::string
+ConsulClient::GetNodeName()
+{
+ using namespace std::literals;
+
+ HttpClient::KeyValueMap AdditionalHeaders;
+ ApplyCommonHeaders(AdditionalHeaders);
+
+ HttpClient::Response Result = m_HttpClient.Get("v1/agent/self", AdditionalHeaders);
if (!Result)
{
- ZEN_WARN("ConsulClient::DeregisterService() failed to deregister service '{}' ({})", ServiceId, Result.ErrorMessage(""));
- return false;
+ return {};
}
- return true;
+ std::string JsonError;
+ CbFieldIterator Root = LoadCompactBinaryFromJson(Result.AsText(), JsonError);
+ if (!Root || !JsonError.empty())
+ {
+ return {};
+ }
+
+ for (CbFieldView Field : Root)
+ {
+ if (Field.GetName() == "Config"sv)
+ {
+ CbObjectView Config = Field.AsObjectView();
+ if (Config)
+ {
+ return std::string(Config["NodeName"sv].AsString());
+ }
+ }
+ }
+
+ return {};
}
void
ConsulClient::ApplyCommonHeaders(HttpClient::KeyValueMap& InOutHeaderMap)
{
- if (!m_Token.empty())
+ std::string Token;
+ if (!m_Config.StaticToken.empty())
{
- InOutHeaderMap.Entries.emplace("X-Consul-Token", m_Token);
+ Token = m_Config.StaticToken;
+ }
+ else if (!m_Config.TokenEnvName.empty())
+ {
+ Token = GetEnvVariable(m_Config.TokenEnvName);
+ }
+
+ if (!Token.empty())
+ {
+ InOutHeaderMap.Entries.emplace("X-Consul-Token", Token);
}
}
@@ -446,4 +541,196 @@ ServiceRegistration::RegistrationLoop()
}
}
+//////////////////////////////////////////////////////////////////////////
+// Tests
+
+#if ZEN_WITH_TESTS
+
+void
+consul_forcelink()
+{
+}
+
+struct MockHealthService : public HttpService
+{
+ std::atomic<bool> FailHealth{false};
+ std::atomic<int> HealthCheckCount{0};
+
+ const char* BaseUri() const override { return "/"; }
+
+ void HandleRequest(HttpServerRequest& Request) override
+ {
+ std::string_view Uri = Request.RelativeUri();
+ if (Uri == "health/" || Uri == "health")
+ {
+ HealthCheckCount.fetch_add(1);
+ if (FailHealth.load())
+ {
+ Request.WriteResponse(HttpResponseCode::ServiceUnavailable);
+ }
+ else
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok");
+ }
+ return;
+ }
+ Request.WriteResponse(HttpResponseCode::NotFound);
+ }
+};
+
+struct TestHealthServer
+{
+ MockHealthService Mock;
+
+ void Start()
+ {
+ m_TmpDir.emplace();
+ m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"});
+ m_Port = m_Server->Initialize(0, m_TmpDir->Path() / "http");
+ REQUIRE(m_Port != -1);
+ m_Server->RegisterService(Mock);
+ m_ServerThread = std::thread([this]() { m_Server->Run(false); });
+ }
+
+ int Port() const { return m_Port; }
+
+ ~TestHealthServer()
+ {
+ if (m_Server)
+ {
+ m_Server->RequestExit();
+ }
+ if (m_ServerThread.joinable())
+ {
+ m_ServerThread.join();
+ }
+ if (m_Server)
+ {
+ m_Server->Close();
+ }
+ }
+
+private:
+ std::optional<ScopedTemporaryDirectory> m_TmpDir;
+ Ref<HttpServer> m_Server;
+ std::thread m_ServerThread;
+ int m_Port = -1;
+};
+
+static bool
+WaitForCondition(std::function<bool()> Predicate, int TimeoutMs, int PollIntervalMs = 200)
+{
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < static_cast<uint64_t>(TimeoutMs))
+ {
+ if (Predicate())
+ {
+ return true;
+ }
+ Sleep(PollIntervalMs);
+ }
+ return Predicate();
+}
+
+static std::string
+GetCheckStatus(ConsulClient& Client, std::string_view ServiceId)
+{
+ using namespace std::literals;
+
+ std::string JsonError;
+ CbFieldIterator ChecksRoot = LoadCompactBinaryFromJson(Client.GetAgentChecksJson(), JsonError);
+ if (!ChecksRoot || !JsonError.empty())
+ {
+ return {};
+ }
+
+ for (CbFieldView F : ChecksRoot)
+ {
+ if (!F.IsObject())
+ {
+ continue;
+ }
+ for (CbFieldView C : F.AsObjectView())
+ {
+ CbObjectView Check = C.AsObjectView();
+ if (Check["ServiceID"sv].AsString() == ServiceId)
+ {
+ return std::string(Check["Status"sv].AsString());
+ }
+ }
+ }
+ return {};
+}
+
+TEST_SUITE_BEGIN("util.consul");
+
+TEST_CASE("util.consul.service_lifecycle")
+{
+ ConsulProcess ConsulProc;
+ ConsulProc.SpawnConsulAgent();
+
+ TestHealthServer HealthServer;
+ HealthServer.Start();
+
+ ConsulClient Client({.BaseUri = "http://localhost:8500/"});
+
+ const std::string ServiceId = "test-health-svc";
+
+ ServiceRegistrationInfo Info;
+ Info.ServiceId = ServiceId;
+ Info.ServiceName = "zen-test-health";
+ Info.Address = "127.0.0.1";
+ Info.Port = static_cast<uint16_t>(HealthServer.Port());
+ Info.HealthEndpoint = "health/";
+ Info.HealthIntervalSeconds = 1;
+ Info.DeregisterAfterSeconds = 60;
+
+ // Phase 1: Register and verify Consul sends health checks to our service
+ REQUIRE(Client.RegisterService(Info));
+ REQUIRE(Client.HasService(ServiceId));
+
+ REQUIRE(WaitForCondition([&]() { return HealthServer.Mock.HealthCheckCount.load() >= 1; }, 10000, 50));
+ CHECK(HealthServer.Mock.HealthCheckCount.load() >= 1);
+ CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing");
+
+ // Phase 2: Explicit deregister
+ REQUIRE(Client.DeregisterService(ServiceId));
+ CHECK_FALSE(Client.HasService(ServiceId));
+
+ // Phase 3: Register with InitialStatus, verify immediately passing before any health check fires,
+ // then fail health and verify check goes critical
+ HealthServer.Mock.HealthCheckCount.store(0);
+ HealthServer.Mock.FailHealth.store(false);
+
+ Info.InitialStatus = "passing";
+ REQUIRE(Client.RegisterService(Info));
+ REQUIRE(Client.HasService(ServiceId));
+
+ CHECK_EQ(HealthServer.Mock.HealthCheckCount.load(), 0);
+ CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing");
+
+ REQUIRE(WaitForCondition([&]() { return HealthServer.Mock.HealthCheckCount.load() >= 1; }, 10000, 50));
+ CHECK_EQ(GetCheckStatus(Client, ServiceId), "passing");
+
+ HealthServer.Mock.FailHealth.store(true);
+
+ // Wait for Consul to observe the failing check
+ REQUIRE(WaitForCondition([&]() { return GetCheckStatus(Client, ServiceId) == "critical"; }, 10000, 50));
+ CHECK_EQ(GetCheckStatus(Client, ServiceId), "critical");
+
+ // Phase 4: Explicit deregister while critical
+ REQUIRE(Client.DeregisterService(ServiceId));
+ CHECK_FALSE(Client.HasService(ServiceId));
+
+ // Phase 5: Deregister an already-deregistered service - should not crash
+ Client.DeregisterService(ServiceId);
+ CHECK_FALSE(Client.HasService(ServiceId));
+
+ ConsulProc.StopConsulAgent();
+}
+
+TEST_SUITE_END();
+
+#endif
+
} // namespace zen::consul
diff --git a/src/zenremotestore/filesystemutils.cpp b/src/zenutil/filesystemutils.cpp
index fdb2143d8..ccc42a838 100644
--- a/src/zenremotestore/filesystemutils.cpp
+++ b/src/zenutil/filesystemutils.cpp
@@ -1,8 +1,6 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zenremotestore/filesystemutils.h>
-
-#include <zenremotestore/chunking/chunkedcontent.h>
+#include <zenutil/filesystemutils.h>
#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
@@ -83,117 +81,6 @@ BufferedOpenFile::GetRange(uint64_t Offset, uint64_t Size)
return Result;
}
-ReadFileCache::ReadFileCache(std::atomic<uint64_t>& OpenReadCount,
- std::atomic<uint64_t>& CurrentOpenFileCount,
- std::atomic<uint64_t>& ReadCount,
- std::atomic<uint64_t>& ReadByteCount,
- const std::filesystem::path& Path,
- const ChunkedFolderContent& LocalContent,
- const ChunkedContentLookup& LocalLookup,
- size_t MaxOpenFileCount)
-: m_Path(Path)
-, m_LocalContent(LocalContent)
-, m_LocalLookup(LocalLookup)
-, m_OpenReadCount(OpenReadCount)
-, m_CurrentOpenFileCount(CurrentOpenFileCount)
-, m_ReadCount(ReadCount)
-, m_ReadByteCount(ReadByteCount)
-{
- m_OpenFiles.reserve(MaxOpenFileCount);
-}
-ReadFileCache::~ReadFileCache()
-{
- m_OpenFiles.clear();
-}
-
-CompositeBuffer
-ReadFileCache::GetRange(uint32_t SequenceIndex, uint64_t Offset, uint64_t Size)
-{
- ZEN_TRACE_CPU("ReadFileCache::GetRange");
-
- auto CacheIt =
- std::find_if(m_OpenFiles.begin(), m_OpenFiles.end(), [SequenceIndex](const auto& Lhs) { return Lhs.first == SequenceIndex; });
- if (CacheIt != m_OpenFiles.end())
- {
- if (CacheIt != m_OpenFiles.begin())
- {
- auto CachedFile(std::move(CacheIt->second));
- m_OpenFiles.erase(CacheIt);
- m_OpenFiles.insert(m_OpenFiles.begin(), std::make_pair(SequenceIndex, std::move(CachedFile)));
- }
- CompositeBuffer Result = m_OpenFiles.front().second->GetRange(Offset, Size);
- return Result;
- }
- const uint32_t LocalPathIndex = m_LocalLookup.SequenceIndexFirstPathIndex[SequenceIndex];
- const std::filesystem::path LocalFilePath = (m_Path / m_LocalContent.Paths[LocalPathIndex]).make_preferred();
- if (Size == m_LocalContent.RawSizes[LocalPathIndex])
- {
- IoBuffer Result = IoBufferBuilder::MakeFromFile(LocalFilePath);
- return CompositeBuffer(SharedBuffer(Result));
- }
- if (m_OpenFiles.size() == m_OpenFiles.capacity())
- {
- m_OpenFiles.pop_back();
- }
- m_OpenFiles.insert(
- m_OpenFiles.begin(),
- std::make_pair(
- SequenceIndex,
- std::make_unique<BufferedOpenFile>(LocalFilePath, m_OpenReadCount, m_CurrentOpenFileCount, m_ReadCount, m_ReadByteCount)));
- CompositeBuffer Result = m_OpenFiles.front().second->GetRange(Offset, Size);
- return Result;
-}
-
-uint32_t
-SetNativeFileAttributes(const std::filesystem::path FilePath, SourcePlatform SourcePlatform, uint32_t Attributes)
-{
-#if ZEN_PLATFORM_WINDOWS
- if (SourcePlatform == SourcePlatform::Windows)
- {
- SetFileAttributesToPath(FilePath, Attributes);
- return Attributes;
- }
- else
- {
- uint32_t CurrentAttributes = GetFileAttributesFromPath(FilePath);
- uint32_t NewAttributes = zen::MakeFileAttributeReadOnly(CurrentAttributes, zen::IsFileModeReadOnly(Attributes));
- if (CurrentAttributes != NewAttributes)
- {
- SetFileAttributesToPath(FilePath, NewAttributes);
- }
- return NewAttributes;
- }
-#endif // ZEN_PLATFORM_WINDOWS
-#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
- if (SourcePlatform != SourcePlatform::Windows)
- {
- zen::SetFileMode(FilePath, Attributes);
- return Attributes;
- }
- else
- {
- uint32_t CurrentMode = zen::GetFileMode(FilePath);
- uint32_t NewMode = zen::MakeFileModeReadOnly(CurrentMode, zen::IsFileAttributeReadOnly(Attributes));
- if (CurrentMode != NewMode)
- {
- zen::SetFileMode(FilePath, NewMode);
- }
- return NewMode;
- }
-#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
-};
-
-uint32_t
-GetNativeFileAttributes(const std::filesystem::path FilePath)
-{
-#if ZEN_PLATFORM_WINDOWS
- return GetFileAttributesFromPath(FilePath);
-#endif // ZEN_PLATFORM_WINDOWS
-#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
- return GetFileMode(FilePath);
-#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
-}
-
bool
IsFileWithRetry(const std::filesystem::path& Path)
{
@@ -256,6 +143,21 @@ RenameFileWithRetry(const std::filesystem::path& SourcePath, const std::filesyst
}
std::error_code
+RenameDirectoryWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath)
+{
+ std::error_code Ec;
+ RenameDirectory(SourcePath, TargetPath, Ec);
+ for (size_t Retries = 0; Ec && Retries < 5; Retries++)
+ {
+ ZEN_ASSERT_SLOW(IsDir(SourcePath));
+ Sleep(50 + int(Retries * 150));
+ Ec.clear();
+ RenameDirectory(SourcePath, TargetPath, Ec);
+ }
+ return Ec;
+}
+
+std::error_code
TryRemoveFile(const std::filesystem::path& Path)
{
std::error_code Ec;
@@ -336,6 +238,124 @@ FastCopyFile(bool AllowFileClone,
}
}
+void
+GetDirectoryContent(WorkerThreadPool& WorkerPool,
+ const std::filesystem::path& Path,
+ DirectoryContentFlags Flags,
+ DirectoryContent& OutContent)
+{
+ struct Visitor : public GetDirectoryContentVisitor
+ {
+ Visitor(zen::DirectoryContent& OutContent, const std::filesystem::path& InRootPath) : Content(OutContent), RootPath(InRootPath) {}
+ virtual bool AsyncAllowDirectory(const std::filesystem::path& Parent, const std::filesystem::path& DirectoryName) const
+ {
+ ZEN_UNUSED(Parent, DirectoryName);
+ return true;
+ }
+ virtual void AsyncVisitDirectory(const std::filesystem::path& RelativeRoot, DirectoryContent&& InContent)
+ {
+ std::vector<std::filesystem::path> Files;
+ std::vector<std::filesystem::path> Directories;
+
+ if (!InContent.FileNames.empty())
+ {
+ Files.reserve(InContent.FileNames.size());
+ for (const std::filesystem::path& FileName : InContent.FileNames)
+ {
+ if (RelativeRoot.empty())
+ {
+ Files.push_back(RootPath / FileName);
+ }
+ else
+ {
+ Files.push_back(RootPath / RelativeRoot / FileName);
+ }
+ }
+ }
+
+ if (!InContent.DirectoryNames.empty())
+ {
+ Directories.reserve(InContent.DirectoryNames.size());
+ for (const std::filesystem::path& DirName : InContent.DirectoryNames)
+ {
+ if (RelativeRoot.empty())
+ {
+ Directories.push_back(RootPath / DirName);
+ }
+ else
+ {
+ Directories.push_back(RootPath / RelativeRoot / DirName);
+ }
+ }
+ }
+
+ Lock.WithExclusiveLock([&]() {
+ if (!InContent.FileNames.empty())
+ {
+ for (const std::filesystem::path& FileName : InContent.FileNames)
+ {
+ if (RelativeRoot.empty())
+ {
+ Content.Files.push_back(RootPath / FileName);
+ }
+ else
+ {
+ Content.Files.push_back(RootPath / RelativeRoot / FileName);
+ }
+ }
+ }
+ if (!InContent.FileSizes.empty())
+ {
+ Content.FileSizes.insert(Content.FileSizes.end(), InContent.FileSizes.begin(), InContent.FileSizes.end());
+ }
+ if (!InContent.FileAttributes.empty())
+ {
+ Content.FileAttributes.insert(Content.FileAttributes.end(),
+ InContent.FileAttributes.begin(),
+ InContent.FileAttributes.end());
+ }
+ if (!InContent.FileModificationTicks.empty())
+ {
+ Content.FileModificationTicks.insert(Content.FileModificationTicks.end(),
+ InContent.FileModificationTicks.begin(),
+ InContent.FileModificationTicks.end());
+ }
+
+ if (!InContent.DirectoryNames.empty())
+ {
+ for (const std::filesystem::path& DirName : InContent.DirectoryNames)
+ {
+ if (RelativeRoot.empty())
+ {
+ Content.Directories.push_back(RootPath / DirName);
+ }
+ else
+ {
+ Content.Directories.push_back(RootPath / RelativeRoot / DirName);
+ }
+ }
+ }
+ if (!InContent.DirectoryAttributes.empty())
+ {
+ Content.DirectoryAttributes.insert(Content.DirectoryAttributes.end(),
+ InContent.DirectoryAttributes.begin(),
+ InContent.DirectoryAttributes.end());
+ }
+ });
+ }
+ RwLock Lock;
+ zen::DirectoryContent& Content;
+ const std::filesystem::path& RootPath;
+ };
+
+ Visitor RootVisitor(OutContent, Path);
+
+ Latch PendingWork(1);
+ GetDirectoryContent(Path, Flags, RootVisitor, WorkerPool, PendingWork);
+ PendingWork.CountDown();
+ PendingWork.Wait();
+}
+
CleanDirectoryResult
CleanDirectory(
WorkerThreadPool& IOWorkerPool,
@@ -468,7 +488,7 @@ CleanDirectory(
if (!FailedRemovePaths.empty())
{
RwLock::ExclusiveLockScope _(ResultLock);
- FailedRemovePaths.insert(FailedRemovePaths.end(), FailedRemovePaths.begin(), FailedRemovePaths.end());
+ Result.FailedRemovePaths.insert(Result.FailedRemovePaths.end(), FailedRemovePaths.begin(), FailedRemovePaths.end());
}
else if (!RelativeRoot.empty())
{
@@ -637,7 +657,7 @@ namespace {
void GenerateFile(const std::filesystem::path& Path) { BasicFile _(Path, BasicFile::Mode::kTruncate); }
} // namespace
-TEST_SUITE_BEGIN("remotestore.filesystemutils");
+TEST_SUITE_BEGIN("util.filesystemutils");
TEST_CASE("filesystemutils.CleanDirectory")
{
diff --git a/src/zenutil/include/zenutil/cloud/mockimds.h b/src/zenutil/include/zenutil/cloud/mockimds.h
index d0c0155b0..28e1e8ba6 100644
--- a/src/zenutil/include/zenutil/cloud/mockimds.h
+++ b/src/zenutil/include/zenutil/cloud/mockimds.h
@@ -23,7 +23,7 @@ namespace zen::compute {
*
* When a request arrives for a provider that is not the ActiveProvider, the
* mock returns 404, causing CloudMetadata to write a sentinel file and move
- * on to the next provider — exactly like a failed probe on bare metal.
+ * on to the next provider - exactly like a failed probe on bare metal.
*
* All config fields are public and can be mutated between poll cycles to
* simulate state changes (e.g. a spot interruption appearing mid-run).
@@ -45,13 +45,13 @@ public:
std::string AvailabilityZone = "us-east-1a";
std::string LifeCycle = "on-demand"; // "spot" or "on-demand"
- // Empty string → endpoint returns 404 (instance not in an ASG).
- // Non-empty → returned as the response body. "InService" means healthy;
+ // Empty string -> endpoint returns 404 (instance not in an ASG).
+ // Non-empty -> returned as the response body. "InService" means healthy;
// anything else (e.g. "Terminated:Wait") triggers termination detection.
std::string AutoscalingState;
- // Empty string → endpoint returns 404 (no spot interruption).
- // Non-empty → returned as the response body, signalling a spot reclaim.
+ // Empty string -> endpoint returns 404 (no spot interruption).
+ // Non-empty -> returned as the response body, signalling a spot reclaim.
std::string SpotAction;
// IAM credential fields for ImdsCredentialProvider testing
@@ -69,10 +69,10 @@ public:
std::string Location = "eastus";
std::string Priority = "Regular"; // "Spot" or "Regular"
- // Empty → instance is not in a VM Scale Set (no autoscaling).
+ // Empty -> instance is not in a VM Scale Set (no autoscaling).
std::string VmScaleSetName;
- // Empty → no scheduled events. Set to "Preempt", "Terminate", or
+ // Empty -> no scheduled events. Set to "Preempt", "Terminate", or
// "Reboot" to simulate a termination-class event.
std::string ScheduledEventType;
std::string ScheduledEventStatus = "Scheduled";
diff --git a/src/zenutil/include/zenutil/cloud/s3client.h b/src/zenutil/include/zenutil/cloud/s3client.h
index bd30aa8a2..b0402d231 100644
--- a/src/zenutil/include/zenutil/cloud/s3client.h
+++ b/src/zenutil/include/zenutil/cloud/s3client.h
@@ -35,9 +35,7 @@ struct S3ClientOptions
/// Overrides the static Credentials field.
Ref<ImdsCredentialProvider> CredentialProvider;
- std::chrono::milliseconds ConnectTimeout{5000};
- std::chrono::milliseconds Timeout{};
- uint8_t RetryCount = 3;
+ HttpClientSettings HttpSettings = {.LogCategory = "s3", .ConnectTimeout = std::chrono::milliseconds(5000), .RetryCount = 3};
};
struct S3ObjectInfo
@@ -70,6 +68,8 @@ struct S3GetObjectResult : S3Result
IoBuffer Content;
std::string_view AsText() const { return std::string_view(reinterpret_cast<const char*>(Content.GetData()), Content.GetSize()); }
+
+ static std::string_view NotFoundErrorText;
};
/// Result of HeadObject - carries object metadata and existence status.
@@ -119,7 +119,7 @@ public:
S3Result PutObject(std::string_view Key, IoBuffer Content);
/// Download an object from S3
- S3GetObjectResult GetObject(std::string_view Key);
+ S3GetObjectResult GetObject(std::string_view Key, const std::filesystem::path& TempFilePath = {});
/// Download a byte range of an object from S3
/// @param RangeStart First byte offset (inclusive)
@@ -219,6 +219,7 @@ private:
SigV4Credentials m_Credentials;
Ref<ImdsCredentialProvider> m_CredentialProvider;
HttpClient m_HttpClient;
+ bool m_Verbose = false;
// Cached signing key (only changes once per day, protected by RwLock for thread safety)
mutable RwLock m_SigningKeyLock;
diff --git a/src/zenutil/include/zenutil/consoletui.h b/src/zenutil/include/zenutil/consoletui.h
index 22737589b..49bb0cc92 100644
--- a/src/zenutil/include/zenutil/consoletui.h
+++ b/src/zenutil/include/zenutil/consoletui.h
@@ -30,7 +30,7 @@ bool IsTuiAvailable();
// - Title: a short description printed once above the list
// - Items: pre-formatted display labels, one per selectable entry
//
-// Arrow keys (↑/↓) navigate the selection, Enter confirms, Esc cancels.
+// Arrow keys (^/v) navigate the selection, Enter confirms, Esc cancels.
// Returns the index of the selected item, or -1 if the user cancelled.
//
// Precondition: IsTuiAvailable() must be true.
diff --git a/src/zenutil/include/zenutil/consul.h b/src/zenutil/include/zenutil/consul.h
index 4002d5d23..c3d0e5f1d 100644
--- a/src/zenutil/include/zenutil/consul.h
+++ b/src/zenutil/include/zenutil/consul.h
@@ -23,12 +23,20 @@ struct ServiceRegistrationInfo
std::vector<std::pair<std::string, std::string>> Tags;
uint32_t HealthIntervalSeconds = 10;
uint32_t DeregisterAfterSeconds = 30;
+ std::string InitialStatus;
};
class ConsulClient
{
public:
- ConsulClient(std::string_view BaseUri, std::string_view Token = "");
+ struct Configuration
+ {
+ std::string BaseUri;
+ std::string StaticToken;
+ std::string TokenEnvName;
+ };
+
+ ConsulClient(const Configuration& Config);
~ConsulClient();
ConsulClient(const ConsulClient&) = delete;
@@ -55,9 +63,10 @@ public:
private:
static bool FindServiceInJson(std::string_view Json, std::string_view ServiceId);
void ApplyCommonHeaders(HttpClient::KeyValueMap& InOutHeaderMap);
+ std::string GetNodeName();
- std::string m_Token;
- HttpClient m_HttpClient;
+ Configuration m_Config;
+ HttpClient m_HttpClient;
};
class ConsulProcess
@@ -109,4 +118,6 @@ private:
void RegistrationLoop();
};
+void consul_forcelink();
+
} // namespace zen::consul
diff --git a/src/zenremotestore/include/zenremotestore/filesystemutils.h b/src/zenutil/include/zenutil/filesystemutils.h
index cb2d718f7..05defd1a8 100644
--- a/src/zenremotestore/include/zenremotestore/filesystemutils.h
+++ b/src/zenutil/include/zenutil/filesystemutils.h
@@ -3,7 +3,7 @@
#pragma once
#include <zencore/basicfile.h>
-#include <zenremotestore/chunking/chunkedcontent.h>
+#include <zencore/filesystem.h>
namespace zen {
@@ -42,42 +42,12 @@ private:
IoBuffer m_Cache;
};
-class ReadFileCache
-{
-public:
- // A buffered file reader that provides CompositeBuffer where the buffers are owned and the memory never overwritten
- ReadFileCache(std::atomic<uint64_t>& OpenReadCount,
- std::atomic<uint64_t>& CurrentOpenFileCount,
- std::atomic<uint64_t>& ReadCount,
- std::atomic<uint64_t>& ReadByteCount,
- const std::filesystem::path& Path,
- const ChunkedFolderContent& LocalContent,
- const ChunkedContentLookup& LocalLookup,
- size_t MaxOpenFileCount);
- ~ReadFileCache();
-
- CompositeBuffer GetRange(uint32_t SequenceIndex, uint64_t Offset, uint64_t Size);
-
-private:
- const std::filesystem::path m_Path;
- const ChunkedFolderContent& m_LocalContent;
- const ChunkedContentLookup& m_LocalLookup;
- std::vector<std::pair<uint32_t, std::unique_ptr<BufferedOpenFile>>> m_OpenFiles;
- std::atomic<uint64_t>& m_OpenReadCount;
- std::atomic<uint64_t>& m_CurrentOpenFileCount;
- std::atomic<uint64_t>& m_ReadCount;
- std::atomic<uint64_t>& m_ReadByteCount;
-};
-
-uint32_t SetNativeFileAttributes(const std::filesystem::path FilePath, SourcePlatform SourcePlatform, uint32_t Attributes);
-
-uint32_t GetNativeFileAttributes(const std::filesystem::path FilePath);
-
bool IsFileWithRetry(const std::filesystem::path& Path);
bool SetFileReadOnlyWithRetry(const std::filesystem::path& Path, bool ReadOnly);
std::error_code RenameFileWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath);
+std::error_code RenameDirectoryWithRetry(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath);
std::error_code TryRemoveFile(const std::filesystem::path& Path);
@@ -101,6 +71,13 @@ struct CleanDirectoryResult
std::vector<std::pair<std::filesystem::path, std::error_code>> FailedRemovePaths;
};
+class WorkerThreadPool;
+
+void GetDirectoryContent(WorkerThreadPool& WorkerPool,
+ const std::filesystem::path& Path,
+ DirectoryContentFlags Flags,
+ DirectoryContent& OutContent);
+
CleanDirectoryResult CleanDirectory(
WorkerThreadPool& IOWorkerPool,
std::atomic<bool>& AbortFlag,
diff --git a/src/zenutil/include/zenutil/process/subprocessmanager.h b/src/zenutil/include/zenutil/process/subprocessmanager.h
index 4a25170df..95d7fa43d 100644
--- a/src/zenutil/include/zenutil/process/subprocessmanager.h
+++ b/src/zenutil/include/zenutil/process/subprocessmanager.h
@@ -95,19 +95,24 @@ public:
/// Spawn a new child process and begin monitoring it.
///
/// If Options.StdoutPipe is set, the pipe is consumed and async reading
- /// begins automatically. Similarly for Options.StderrPipe.
+ /// begins automatically. Similarly for Options.StderrPipe. When providing
+ /// pipes, pass the corresponding data callback here so it is installed
+ /// before the first async read completes - setting it later via
+ /// SetStdoutCallback risks losing early output.
///
/// Returns a non-owning pointer valid until Remove() or manager destruction.
/// The exit callback fires on an io_context thread when the process terminates.
ManagedProcess* Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit);
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout = {},
+ ProcessDataCallback OnStderr = {});
/// Adopt an already-running process by handle. Takes ownership of handle internals.
ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit);
- /// Stop monitoring a process by pid. Does NOT kill the process — call
+ /// Stop monitoring a process by pid. Does NOT kill the process - call
/// process->Kill() first if needed. The exit callback will not fire after
/// this returns.
void Remove(int Pid);
@@ -182,12 +187,6 @@ public:
/// yet computed.
[[nodiscard]] float GetCpuUsagePercent() const;
- /// Set per-process stdout callback (overrides manager default).
- void SetStdoutCallback(ProcessDataCallback Callback);
-
- /// Set per-process stderr callback (overrides manager default).
- void SetStderrCallback(ProcessDataCallback Callback);
-
/// Return all stdout captured so far. When a callback is set, output is
/// delivered there instead of being accumulated.
[[nodiscard]] std::string GetCapturedStdout() const;
@@ -220,7 +219,7 @@ private:
/// A group of managed processes with OS-level backing.
///
/// On Windows: backed by a JobObject. All processes assigned on spawn.
-/// Kill-on-close guarantee — if the group is destroyed, the OS terminates
+/// Kill-on-close guarantee - if the group is destroyed, the OS terminates
/// all member processes.
/// On Linux/macOS: uses setpgid() so children share a process group.
/// Enables bulk signal delivery via kill(-pgid, sig).
@@ -237,11 +236,14 @@ public:
/// Group name (as passed to CreateGroup).
[[nodiscard]] std::string_view GetName() const;
- /// Spawn a process into this group.
+ /// Spawn a process into this group. See SubprocessManager::Spawn for
+ /// details on the stdout/stderr callback parameters.
ManagedProcess* Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit);
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout = {},
+ ProcessDataCallback OnStderr = {});
/// Adopt an already-running process into this group.
/// On Windows the process is assigned to the group's JobObject.
diff --git a/src/zenutil/include/zenutil/progress.h b/src/zenutil/include/zenutil/progress.h
new file mode 100644
index 000000000..6a137ae9c
--- /dev/null
+++ b/src/zenutil/include/zenutil/progress.h
@@ -0,0 +1,65 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logbase.h>
+
+#include <memory>
+#include <string>
+
+namespace zen {
+
+class ProgressBase
+{
+public:
+ virtual ~ProgressBase() = default;
+
+ virtual void SetLogOperationName(std::string_view Name) = 0;
+ virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) = 0;
+ virtual uint32_t GetProgressUpdateDelayMS() const = 0;
+
+ class ProgressBar
+ {
+ public:
+ struct State
+ {
+ bool operator==(const State&) const = default;
+ std::string Task;
+ std::string Details;
+ uint64_t TotalCount = 0;
+ uint64_t RemainingCount = 0;
+ uint64_t OptionalElapsedTime = (uint64_t)-1;
+ enum class EStatus
+ {
+ Running,
+ Aborted,
+ Paused
+ };
+ EStatus Status = EStatus::Running;
+
+ static constexpr EStatus CalculateStatus(bool IsAborted, bool IsPaused)
+ {
+ if (IsAborted)
+ {
+ return EStatus::Aborted;
+ }
+ if (IsPaused)
+ {
+ return EStatus::Paused;
+ }
+ return EStatus::Running;
+ }
+ };
+
+ virtual ~ProgressBar() = default;
+
+ virtual void UpdateState(const State& NewState, bool DoLinebreak) = 0;
+ virtual void Finish() = 0;
+ };
+
+ virtual std::unique_ptr<ProgressBar> CreateProgressBar(std::string_view InSubTask) = 0;
+};
+
+ProgressBase* CreateStandardProgress(LoggerRef Log);
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/sessionsclient.h b/src/zenutil/include/zenutil/sessionsclient.h
index aca45e61d..12ff5e593 100644
--- a/src/zenutil/include/zenutil/sessionsclient.h
+++ b/src/zenutil/include/zenutil/sessionsclient.h
@@ -35,13 +35,13 @@ public:
SessionsServiceClient(const SessionsServiceClient&) = delete;
SessionsServiceClient& operator=(const SessionsServiceClient&) = delete;
- /// POST /sessions/{id} — register or re-announce the session with optional metadata.
+ /// POST /sessions/{id} - register or re-announce the session with optional metadata.
[[nodiscard]] bool Announce(CbObjectView Metadata = {});
- /// PUT /sessions/{id} — update metadata on an existing session.
+ /// PUT /sessions/{id} - update metadata on an existing session.
[[nodiscard]] bool UpdateMetadata(CbObjectView Metadata = {});
- /// DELETE /sessions/{id} — remove the session.
+ /// DELETE /sessions/{id} - remove the session.
[[nodiscard]] bool Remove();
/// Create a logging sink that forwards log messages to the session's log endpoint.
diff --git a/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h b/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h
index f4ac5ff22..4387e616a 100644
--- a/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h
+++ b/src/zenutil/include/zenutil/splitconsole/tcplogstreamsink.h
@@ -85,12 +85,12 @@ public:
void Flush() override
{
- // Nothing to flush — writes happen asynchronously
+ // Nothing to flush - writes happen asynchronously
}
void SetFormatter(std::unique_ptr<logging::Formatter> /*InFormatter*/) override
{
- // Not used — we output the raw payload directly
+ // Not used - we output the raw payload directly
}
private:
@@ -124,7 +124,7 @@ private:
{
break; // don't retry during shutdown
}
- continue; // drop batch — will retry on next batch
+ continue; // drop batch - will retry on next batch
}
// Build a gathered buffer sequence so the entire batch is written
@@ -176,7 +176,7 @@ private:
std::string m_Source;
uint32_t m_MaxQueueSize;
- // Sequence counter — incremented atomically by Log() callers.
+ // Sequence counter - incremented atomically by Log() callers.
// Gaps in the sequence seen by the receiver indicate dropped messages.
std::atomic<uint64_t> m_NextSequence{0};
diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h
index 03d507400..d6f66fbea 100644
--- a/src/zenutil/include/zenutil/zenserverprocess.h
+++ b/src/zenutil/include/zenutil/zenserverprocess.h
@@ -66,6 +66,7 @@ public:
std::filesystem::path CreateNewTestDir();
std::filesystem::path CreateChildDir(std::string_view ChildName);
std::filesystem::path ProgramBaseDir() const { return m_ProgramBaseDir; }
+ std::filesystem::path GetChildBaseDir() const { return m_ChildProcessBaseDir; }
std::filesystem::path GetTestRootDir(std::string_view Path);
inline bool IsInitialized() const { return m_IsInitialized; }
inline bool IsTestEnvironment() const { return m_IsTestInstance; }
diff --git a/src/zenutil/logging/fullformatter.cpp b/src/zenutil/logging/fullformatter.cpp
index 2a4840241..283a8bc37 100644
--- a/src/zenutil/logging/fullformatter.cpp
+++ b/src/zenutil/logging/fullformatter.cpp
@@ -12,6 +12,7 @@
#include <atomic>
#include <chrono>
#include <string>
+#include "zencore/logging.h"
namespace zen::logging {
@@ -25,7 +26,7 @@ struct FullFormatter::Impl
{
}
- explicit Impl(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {}
+ explicit Impl(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' ') {}
std::chrono::time_point<std::chrono::system_clock> m_Epoch;
std::tm m_CachedLocalTm{};
@@ -155,15 +156,7 @@ FullFormatter::Format(const LogMessage& Msg, MemoryBuffer& OutBuffer)
OutBuffer.push_back(' ');
}
- // append logger name if exists
- if (Msg.GetLoggerName().size() > 0)
- {
- OutBuffer.push_back('[');
- helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer);
- OutBuffer.push_back(']');
- OutBuffer.push_back(' ');
- }
-
+ // level
OutBuffer.push_back('[');
if (IsColorEnabled())
{
@@ -177,6 +170,23 @@ FullFormatter::Format(const LogMessage& Msg, MemoryBuffer& OutBuffer)
OutBuffer.push_back(']');
OutBuffer.push_back(' ');
+ // logger name
+ if (Msg.GetLoggerName().size() > 0)
+ {
+ if (IsColorEnabled())
+ {
+ OutBuffer.append("\033[1m"sv);
+ }
+ OutBuffer.push_back('[');
+ helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer);
+ OutBuffer.push_back(']');
+ if (IsColorEnabled())
+ {
+ OutBuffer.append("\033[0m"sv);
+ }
+ OutBuffer.push_back(' ');
+ }
+
// add source location if present
if (Msg.GetSource())
{
diff --git a/src/zenutil/logging/jsonformatter.cpp b/src/zenutil/logging/jsonformatter.cpp
index 673a03c94..c63ad891e 100644
--- a/src/zenutil/logging/jsonformatter.cpp
+++ b/src/zenutil/logging/jsonformatter.cpp
@@ -19,8 +19,6 @@ static void
WriteEscapedString(MemoryBuffer& Dest, std::string_view Text)
{
// Strip ANSI SGR sequences before escaping so they don't appear in JSON output
- static const auto IsEscapeStart = [](char C) { return C == '\033'; };
-
const char* RangeStart = Text.data();
const char* End = Text.data() + Text.size();
diff --git a/src/zenutil/logging/logging.cpp b/src/zenutil/logging/logging.cpp
index aa34fc50c..c1636da61 100644
--- a/src/zenutil/logging/logging.cpp
+++ b/src/zenutil/logging/logging.cpp
@@ -124,7 +124,7 @@ BeginInitializeLogging(const LoggingOptions& LogOptions)
LoggerRef DefaultLogger = zen::logging::Default();
- // Build the broadcast sink — a shared indirection point that all
+ // Build the broadcast sink - a shared indirection point that all
// loggers cloned from the default will share. Adding or removing
// a child sink later is immediately visible to every logger.
std::vector<logging::SinkPtr> BroadcastChildren;
@@ -179,7 +179,7 @@ BeginInitializeLogging(const LoggingOptions& LogOptions)
{
return;
}
- static constinit logging::LogPoint ErrorPoint{{}, logging::Err, "{}"};
+ static constinit logging::LogPoint ErrorPoint{0, 0, logging::Err, "{}"};
if (auto ErrLogger = zen::logging::ErrorLog())
{
try
@@ -249,7 +249,7 @@ FinishInitializeLogging(const LoggingOptions& LogOptions)
const std::string StartLogTime = zen::DateTime::Now().ToIso8601();
logging::Registry::Instance().ApplyAll([&](auto Logger) {
- static constinit logging::LogPoint LogStartPoint{{}, logging::Info, "log starting at {}"};
+ static constinit logging::LogPoint LogStartPoint{0, 0, logging::Info, "log starting at {}"};
Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime));
});
}
diff --git a/src/zenutil/logging/rotatingfilesink.cpp b/src/zenutil/logging/rotatingfilesink.cpp
index 23cf60d16..df59af5fe 100644
--- a/src/zenutil/logging/rotatingfilesink.cpp
+++ b/src/zenutil/logging/rotatingfilesink.cpp
@@ -85,7 +85,7 @@ struct RotatingFileSink::Impl
m_CurrentSize = m_CurrentFile.FileSize(OutEc);
if (OutEc)
{
- // FileSize failed but we have an open file — reset to 0
+ // FileSize failed but we have an open file - reset to 0
// so we can at least attempt writes from the start
m_CurrentSize = 0;
OutEc.clear();
diff --git a/src/zenutil/process/asyncpipereader.cpp b/src/zenutil/process/asyncpipereader.cpp
index 2fdcda30d..8eac350c6 100644
--- a/src/zenutil/process/asyncpipereader.cpp
+++ b/src/zenutil/process/asyncpipereader.cpp
@@ -50,7 +50,7 @@ struct AsyncPipeReader::Impl
int Fd = Pipe.ReadFd;
- // Close the write end — child already has it
+ // Close the write end - child already has it
Pipe.CloseWriteEnd();
// Set non-blocking
@@ -156,7 +156,7 @@ CreateOverlappedStdoutPipe(StdoutPipeHandles& OutPipe)
// The read end should not be inherited by the child
SetHandleInformation(ReadHandle, HANDLE_FLAG_INHERIT, 0);
- // Open the client (write) end — inheritable, for the child process
+ // Open the client (write) end - inheritable, for the child process
SECURITY_ATTRIBUTES Sa;
Sa.nLength = sizeof(Sa);
Sa.lpSecurityDescriptor = nullptr;
@@ -202,7 +202,7 @@ struct AsyncPipeReader::Impl
HANDLE ReadHandle = static_cast<HANDLE>(Pipe.ReadHandle);
- // Close the write end — child already has it
+ // Close the write end - child already has it
Pipe.CloseWriteEnd();
// Take ownership of the read handle
diff --git a/src/zenutil/process/subprocessmanager.cpp b/src/zenutil/process/subprocessmanager.cpp
index 3a91b0a61..d0b912a0d 100644
--- a/src/zenutil/process/subprocessmanager.cpp
+++ b/src/zenutil/process/subprocessmanager.cpp
@@ -196,18 +196,6 @@ ManagedProcess::GetCpuUsagePercent() const
return m_Impl->m_CpuUsagePercent.load();
}
-void
-ManagedProcess::SetStdoutCallback(ProcessDataCallback Callback)
-{
- m_Impl->m_StdoutCallback = std::move(Callback);
-}
-
-void
-ManagedProcess::SetStderrCallback(ProcessDataCallback Callback)
-{
- m_Impl->m_StderrCallback = std::move(Callback);
-}
-
std::string
ManagedProcess::GetCapturedStdout() const
{
@@ -288,7 +276,9 @@ struct SubprocessManager::Impl
ManagedProcess* Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit);
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout,
+ ProcessDataCallback OnStderr);
ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit);
void Remove(int Pid);
void RemoveAll();
@@ -462,7 +452,9 @@ ManagedProcess*
SubprocessManager::Impl::Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit)
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout,
+ ProcessDataCallback OnStderr)
{
bool HasStdout = Options.StdoutPipe != nullptr;
bool HasStderr = Options.StderrPipe != nullptr;
@@ -476,6 +468,16 @@ SubprocessManager::Impl::Spawn(const std::filesystem::path& Executable,
ImplPtr->m_Handle.Initialize(static_cast<int>(Result));
#endif
+ // Install callbacks before starting async readers so no data is missed.
+ if (OnStdout)
+ {
+ ImplPtr->m_StdoutCallback = std::move(OnStdout);
+ }
+ if (OnStderr)
+ {
+ ImplPtr->m_StderrCallback = std::move(OnStderr);
+ }
+
auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr)));
ManagedProcess* Ptr = AddProcess(std::move(Proc));
@@ -719,10 +721,12 @@ ManagedProcess*
SubprocessManager::Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit)
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout,
+ ProcessDataCallback OnStderr)
{
ZEN_TRACE_CPU("SubprocessManager::Spawn");
- return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit));
+ return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit), std::move(OnStdout), std::move(OnStderr));
}
ManagedProcess*
@@ -835,7 +839,9 @@ struct ProcessGroup::Impl
ManagedProcess* Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit);
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout,
+ ProcessDataCallback OnStderr);
ManagedProcess* Adopt(ProcessHandle&& Handle, ProcessExitCallback OnExit);
void Remove(int Pid);
void KillAll();
@@ -884,7 +890,9 @@ ManagedProcess*
ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit)
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout,
+ ProcessDataCallback OnStderr)
{
bool HasStdout = Options.StdoutPipe != nullptr;
bool HasStderr = Options.StderrPipe != nullptr;
@@ -895,7 +903,11 @@ ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable,
Options.AssignToJob = &m_JobObject;
}
#else
- if (m_Pgid > 0)
+ if (m_Pgid == 0)
+ {
+ Options.Flags |= CreateProcOptions::Flag_NewProcessGroup;
+ }
+ else
{
Options.ProcessGroupId = m_Pgid;
}
@@ -917,6 +929,16 @@ ProcessGroup::Impl::Spawn(const std::filesystem::path& Executable,
}
#endif
+ // Install callbacks before starting async readers so no data is missed.
+ if (OnStdout)
+ {
+ ImplPtr->m_StdoutCallback = std::move(OnStdout);
+ }
+ if (OnStderr)
+ {
+ ImplPtr->m_StderrCallback = std::move(OnStderr);
+ }
+
auto Proc = std::unique_ptr<ManagedProcess>(new ManagedProcess(std::move(ImplPtr)));
ManagedProcess* Ptr = AddProcess(std::move(Proc));
@@ -1077,10 +1099,12 @@ ManagedProcess*
ProcessGroup::Spawn(const std::filesystem::path& Executable,
std::string_view CommandLine,
CreateProcOptions& Options,
- ProcessExitCallback OnExit)
+ ProcessExitCallback OnExit,
+ ProcessDataCallback OnStdout,
+ ProcessDataCallback OnStderr)
{
ZEN_TRACE_CPU("ProcessGroup::Spawn");
- return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit));
+ return m_Impl->Spawn(Executable, CommandLine, Options, std::move(OnExit), std::move(OnStdout), std::move(OnStderr));
}
ManagedProcess*
@@ -1185,7 +1209,17 @@ TEST_CASE("SubprocessManager.SpawnAndDetectExit")
CallbackFired = true;
});
- IoContext.run_for(5s);
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (CallbackFired)
+ {
+ break;
+ }
+ }
+ }
CHECK(CallbackFired);
CHECK(ReceivedExitCode == 42);
@@ -1210,7 +1244,17 @@ TEST_CASE("SubprocessManager.SpawnAndDetectCleanExit")
CallbackFired = true;
});
- IoContext.run_for(5s);
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (CallbackFired)
+ {
+ break;
+ }
+ }
+ }
CHECK(CallbackFired);
CHECK(ReceivedExitCode == 0);
@@ -1235,7 +1279,17 @@ TEST_CASE("SubprocessManager.StdoutCapture")
ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; });
- IoContext.run_for(5s);
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (Exited)
+ {
+ break;
+ }
+ }
+ }
CHECK(Exited);
std::string Captured = Proc->GetCapturedStdout();
@@ -1264,7 +1318,17 @@ TEST_CASE("SubprocessManager.StderrCapture")
ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; });
- IoContext.run_for(5s);
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (Exited)
+ {
+ break;
+ }
+ }
+ }
CHECK(Exited);
std::string CapturedErr = Proc->GetCapturedStderr();
@@ -1289,11 +1353,24 @@ TEST_CASE("SubprocessManager.StdoutCallback")
std::string ReceivedData;
bool Exited = false;
- ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; });
+ ManagedProcess* Proc = Manager.Spawn(
+ AppStub,
+ CmdLine,
+ Options,
+ [&](ManagedProcess&, int) { Exited = true; },
+ [&](ManagedProcess&, std::string_view Data) { ReceivedData.append(Data); });
- Proc->SetStdoutCallback([&](ManagedProcess&, std::string_view Data) { ReceivedData.append(Data); });
-
- IoContext.run_for(5s);
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (Exited)
+ {
+ break;
+ }
+ }
+ }
CHECK(Exited);
CHECK(ReceivedData.find("callback_test") != std::string::npos);
@@ -1316,8 +1393,18 @@ TEST_CASE("SubprocessManager.MetricsSampling")
ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { Exited = true; });
- // Run for enough time to get metrics samples
- IoContext.run_for(1s);
+ // Poll until metrics are available
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (Proc->GetLatestMetrics().WorkingSetSize > 0)
+ {
+ break;
+ }
+ }
+ }
ProcessMetrics Metrics = Proc->GetLatestMetrics();
CHECK(Metrics.WorkingSetSize > 0);
@@ -1326,7 +1413,17 @@ TEST_CASE("SubprocessManager.MetricsSampling")
CHECK(Snapshot.size() == 1);
// Let it finish
- IoContext.run_for(3s);
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 10'000)
+ {
+ IoContext.run_for(10ms);
+ if (Exited)
+ {
+ break;
+ }
+ }
+ }
CHECK(Exited);
}
@@ -1350,7 +1447,7 @@ TEST_CASE("SubprocessManager.RemoveWhileRunning")
// Let it start
IoContext.run_for(100ms);
- // Remove without killing — callback should NOT fire after this
+ // Remove without killing - callback should NOT fire after this
Manager.Remove(Pid);
IoContext.run_for(500ms);
@@ -1375,12 +1472,31 @@ TEST_CASE("SubprocessManager.KillAndWaitForExit")
ManagedProcess* Proc = Manager.Spawn(AppStub, CmdLine, Options, [&](ManagedProcess&, int) { CallbackFired = true; });
// Let it start
- IoContext.run_for(200ms);
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (Proc->IsRunning())
+ {
+ break;
+ }
+ }
+ }
Proc->Kill();
- IoContext.run_for(2s);
-
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (CallbackFired)
+ {
+ break;
+ }
+ }
+ }
CHECK(CallbackFired);
}
@@ -1401,7 +1517,17 @@ TEST_CASE("SubprocessManager.AdoptProcess")
Manager.Adopt(ProcessHandle(Result), [&](ManagedProcess&, int ExitCode) { ReceivedExitCode = ExitCode; });
- IoContext.run_for(5s);
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (ReceivedExitCode != -1)
+ {
+ break;
+ }
+ }
+ }
CHECK(ReceivedExitCode == 7);
}
@@ -1424,7 +1550,17 @@ TEST_CASE("SubprocessManager.UserTag")
Proc->SetTag("my-worker-1");
CHECK(Proc->GetTag() == "my-worker-1");
- IoContext.run_for(5s);
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (!ReceivedTag.empty())
+ {
+ break;
+ }
+ }
+ }
CHECK(ReceivedTag == "my-worker-1");
}
@@ -1454,7 +1590,17 @@ TEST_CASE("ProcessGroup.SpawnAndMembership")
CHECK(Group->GetProcessCount() == 2);
CHECK(Manager.GetProcessCount() == 2);
- IoContext.run_for(5s);
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (ExitCount == 2)
+ {
+ break;
+ }
+ }
+ }
CHECK(ExitCount == 2);
}
@@ -1504,7 +1650,17 @@ TEST_CASE("ProcessGroup.AggregateMetrics")
Group->Spawn(AppStub, CmdLine, Options, [](ManagedProcess&, int) {});
// Wait for metrics sampling
- IoContext.run_for(1s);
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (Group->GetAggregateMetrics().TotalWorkingSetSize > 0)
+ {
+ break;
+ }
+ }
+ }
AggregateProcessMetrics GroupAgg = Group->GetAggregateMetrics();
CHECK(GroupAgg.ProcessCount == 2);
@@ -1570,7 +1726,17 @@ TEST_CASE("ProcessGroup.MixedGroupedAndUngrouped")
CHECK(Group->GetProcessCount() == 2);
CHECK(Manager.GetProcessCount() == 3);
- IoContext.run_for(5s);
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ IoContext.run_for(10ms);
+ if (GroupExitCount == 2 && UngroupedExitCode != -1)
+ {
+ break;
+ }
+ }
+ }
CHECK(GroupExitCount == 2);
CHECK(UngroupedExitCode == 0);
@@ -1590,7 +1756,7 @@ TEST_CASE("ProcessGroup.FindGroup")
TEST_CASE("SubprocessManager.StressTest" * doctest::skip())
{
- // Seed for reproducibility — change to explore different orderings
+ // Seed for reproducibility - change to explore different orderings
//
// Note that while this is a stress test, it is still single-threaded
@@ -1619,7 +1785,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip())
// Phase 1: Spawn multiple groups with varied workloads
// ========================================================================
- ZEN_INFO("StressTest: Phase 1 — spawning initial groups");
+ ZEN_INFO("StressTest: Phase 1 - spawning initial groups");
constexpr int NumInitialGroups = 8;
std::vector<std::string> GroupNames;
@@ -1673,7 +1839,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip())
// Phase 2: Randomly kill some groups, create replacements, add ungrouped
// ========================================================================
- ZEN_INFO("StressTest: Phase 2 — random group kills and replacements");
+ ZEN_INFO("StressTest: Phase 2 - random group kills and replacements");
constexpr int NumGroupsToKill = 3;
@@ -1738,7 +1904,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip())
// Phase 3: Rapid spawn/exit churn
// ========================================================================
- ZEN_INFO("StressTest: Phase 3 — rapid spawn/exit churn");
+ ZEN_INFO("StressTest: Phase 3 - rapid spawn/exit churn");
std::atomic<int> ChurnExitCount{0};
int TotalChurnSpawned = 0;
@@ -1762,7 +1928,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip())
// Brief pump to allow some exits to be processed
IoContext.run_for(200ms);
- // Destroy the group — any still-running processes get killed
+ // Destroy the group - any still-running processes get killed
Manager.DestroyGroup(Name);
}
@@ -1772,7 +1938,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip())
// Phase 4: Drain and verify
// ========================================================================
- ZEN_INFO("StressTest: Phase 4 — draining remaining processes");
+ ZEN_INFO("StressTest: Phase 4 - draining remaining processes");
// Check metrics were collected before we wind down
AggregateProcessMetrics Agg = Manager.GetAggregateMetrics();
@@ -1803,7 +1969,7 @@ TEST_CASE("SubprocessManager.StressTest" * doctest::skip())
// (exact count is hard to predict due to killed groups, but should be > 0)
CHECK(TotalExitCallbacks.load() > 0);
- ZEN_INFO("StressTest: PASSED — seed={}", Seed);
+ ZEN_INFO("StressTest: PASSED - seed={}", Seed);
}
TEST_SUITE_END();
diff --git a/src/zenutil/progress.cpp b/src/zenutil/progress.cpp
new file mode 100644
index 000000000..a12076dcc
--- /dev/null
+++ b/src/zenutil/progress.cpp
@@ -0,0 +1,99 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/progress.h>
+
+#include <zencore/logging.h>
+
+namespace zen {
+
+class StandardProgressBase;
+
+class StandardProgressBar : public ProgressBase::ProgressBar
+{
+public:
+ StandardProgressBar(StandardProgressBase& Owner, std::string_view InSubTask) : m_Owner(Owner), m_SubTask(InSubTask) {}
+
+ virtual void UpdateState(const State& NewState, bool DoLinebreak) override;
+ virtual void Finish() override;
+
+private:
+ LoggerRef Log();
+ StandardProgressBase& m_Owner;
+ std::string m_SubTask;
+ State m_State;
+};
+
+class StandardProgressBase : public ProgressBase
+{
+public:
+ StandardProgressBase(LoggerRef Log) : m_Log(Log) {}
+
+ virtual void SetLogOperationName(std::string_view Name) override
+ {
+ m_LogOperationName = Name;
+ ZEN_INFO("{}", m_LogOperationName);
+ }
+ virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override
+ {
+ const size_t PercentDone = StepCount > 0u ? (100u * StepIndex) / StepCount : 0u;
+ ZEN_INFO("{}: {}%", m_LogOperationName, PercentDone);
+ }
+ virtual uint32_t GetProgressUpdateDelayMS() const override { return 2000; }
+ virtual std::unique_ptr<ProgressBar> CreateProgressBar(std::string_view InSubTask) override
+ {
+ return std::make_unique<StandardProgressBar>(*this, InSubTask);
+ }
+
+private:
+ friend class StandardProgressBar;
+ LoggerRef m_Log;
+ std::string m_LogOperationName;
+ LoggerRef Log() { return m_Log; }
+};
+
+LoggerRef
+StandardProgressBar::Log()
+{
+ return m_Owner.Log();
+}
+
+void
+StandardProgressBar::UpdateState(const State& NewState, bool DoLinebreak)
+{
+ ZEN_UNUSED(DoLinebreak);
+ const size_t PercentDone =
+ NewState.TotalCount > 0u ? (100u * (NewState.TotalCount - NewState.RemainingCount)) / NewState.TotalCount : 0u;
+ std::string Task = NewState.Task;
+ switch (NewState.Status)
+ {
+ case State::EStatus::Aborted:
+ Task = "Aborting";
+ break;
+ case State::EStatus::Paused:
+ Task = "Paused";
+ break;
+ default:
+ break;
+ }
+ ZEN_INFO("{}: {}%{}", Task, PercentDone, NewState.Details.empty() ? "" : fmt::format(" {}", NewState.Details));
+ m_State = NewState;
+}
+void
+StandardProgressBar::Finish()
+{
+ if (m_State.RemainingCount > 0)
+ {
+ State NewState = m_State;
+ NewState.RemainingCount = 0;
+ NewState.Details = "";
+ UpdateState(NewState, /*DoLinebreak*/ true);
+ }
+}
+
+ProgressBase*
+CreateStandardProgress(LoggerRef Log)
+{
+ return new StandardProgressBase(Log);
+}
+
+} // namespace zen
diff --git a/src/zenutil/sessionsclient.cpp b/src/zenutil/sessionsclient.cpp
index c62cc4099..ec9c6177a 100644
--- a/src/zenutil/sessionsclient.cpp
+++ b/src/zenutil/sessionsclient.cpp
@@ -21,7 +21,7 @@ namespace zen {
//////////////////////////////////////////////////////////////////////////
//
-// SessionLogSink — batching log sink that forwards to /sessions/{id}/log
+// SessionLogSink - batching log sink that forwards to /sessions/{id}/log
//
static const char*
@@ -108,7 +108,7 @@ public:
void SetFormatter(std::unique_ptr<logging::Formatter> /*InFormatter*/) override
{
- // No formatting needed — we send raw message text
+ // No formatting needed - we send raw message text
}
private:
@@ -124,6 +124,9 @@ private:
{
if (Msg.Type == BufferedLogEntry::Type::Shutdown)
{
+ // Mark complete so WaitAndDequeue returns false on empty queue
+ m_Queue.CompleteAdding();
+
// Drain remaining log entries
BufferedLogEntry Remaining;
while (m_Queue.WaitAndDequeue(Remaining))
@@ -172,7 +175,7 @@ private:
{
SendBatch(Batch);
}
- // Drain remaining
+ m_Queue.CompleteAdding();
while (m_Queue.WaitAndDequeue(Extra))
{
if (Extra.Type == BufferedLogEntry::Type::Log)
@@ -226,7 +229,7 @@ private:
}
catch (const std::exception&)
{
- // Best-effort — silently discard on failure
+ // Best-effort - silently discard on failure
}
}
diff --git a/src/zenutil/splitconsole/logstreamlistener.cpp b/src/zenutil/splitconsole/logstreamlistener.cpp
index 04718b543..df985a196 100644
--- a/src/zenutil/splitconsole/logstreamlistener.cpp
+++ b/src/zenutil/splitconsole/logstreamlistener.cpp
@@ -17,7 +17,7 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
//////////////////////////////////////////////////////////////////////////
-// LogStreamSession — reads CbObject-framed messages from a single TCP connection
+// LogStreamSession - reads CbObject-framed messages from a single TCP connection
class LogStreamSession : public RefCounted
{
@@ -34,7 +34,7 @@ private:
[Self](const asio::error_code& Ec, std::size_t BytesRead) {
if (Ec)
{
- return; // connection closed or error — session ends
+ return; // connection closed or error - session ends
}
Self->m_BufferUsed += BytesRead;
Self->ProcessBuffer();
@@ -119,7 +119,7 @@ private:
m_BufferUsed -= Consumed;
}
- // If buffer is full and we can't parse a message, the message is too large — drop connection
+ // If buffer is full and we can't parse a message, the message is too large - drop connection
if (m_BufferUsed == m_ReadBuf.size())
{
ZEN_WARN("LogStreamSession: buffer full with no complete message, dropping connection");
@@ -141,7 +141,7 @@ private:
struct LogStreamListener::Impl
{
- // Owned io_context mode — creates and runs its own thread
+ // Owned io_context mode - creates and runs its own thread
Impl(LogStreamTarget& Target, uint16_t Port)
: m_Target(Target)
, m_OwnedIoContext(std::make_unique<asio::io_context>())
@@ -154,7 +154,7 @@ struct LogStreamListener::Impl
});
}
- // External io_context mode — caller drives the io_context
+ // External io_context mode - caller drives the io_context
Impl(LogStreamTarget& Target, asio::io_context& IoContext, uint16_t Port) : m_Target(Target), m_Acceptor(IoContext)
{
SetupAcceptor(Port);
@@ -312,7 +312,7 @@ namespace {
logging::LogMessage MakeLogMessage(std::string_view Text, logging::LogLevel Level = logging::Info)
{
- static logging::LogPoint Point{{}, Level, {}};
+ static logging::LogPoint Point{0, 0, Level, {}};
Point.Level = Level;
return logging::LogMessage(Point, "test", Text);
}
@@ -367,7 +367,7 @@ TEST_CASE("DroppedMessageDetection")
asio::ip::tcp::socket Socket(IoContext);
Socket.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), Listener.GetPort()));
- // Send seq=0, then seq=5 — the listener should detect a gap of 4
+ // Send seq=0, then seq=5 - the listener should detect a gap of 4
for (uint64_t Seq : {uint64_t(0), uint64_t(5)})
{
CbObjectWriter Writer;
diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp
index 2b27b2d8b..e1ffeeb3e 100644
--- a/src/zenutil/zenserverprocess.cpp
+++ b/src/zenutil/zenserverprocess.cpp
@@ -181,7 +181,7 @@ ZenServerState::Initialize()
ThrowLastError("Could not map view of Zen server state");
}
#else
- int Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT | O_CLOEXEC, geteuid() == 0 ? 0766 : 0666);
+ int Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT, geteuid() == 0 ? 0766 : 0666);
if (Fd < 0)
{
// Work around a potential issue if the service user is changed in certain configurations.
@@ -191,7 +191,7 @@ ZenServerState::Initialize()
// shared memory object and retry, we'll be able to get past shm_open() so long as we have
// the appropriate permissions to create the shared memory object.
shm_unlink("/UnrealEngineZen");
- Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT | O_CLOEXEC, geteuid() == 0 ? 0766 : 0666);
+ Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT, geteuid() == 0 ? 0766 : 0666);
if (Fd < 0)
{
ThrowLastError("Could not open a shared memory object");
@@ -244,7 +244,7 @@ ZenServerState::InitializeReadOnly()
ThrowLastError("Could not map view of Zen server state");
}
#else
- int Fd = shm_open("/UnrealEngineZen", O_RDONLY | O_CLOEXEC, 0666);
+ int Fd = shm_open("/UnrealEngineZen", O_RDONLY, 0666);
if (Fd < 0)
{
return false;
@@ -267,6 +267,8 @@ ZenServerState::InitializeReadOnly()
ZenServerState::ZenServerEntry*
ZenServerState::Lookup(int DesiredListenPort) const
{
+ const uint32_t OurPid = GetCurrentProcessId();
+
for (int i = 0; i < m_MaxEntryCount; ++i)
{
uint16_t EntryPort = m_Data[i].DesiredListenPort;
@@ -274,6 +276,14 @@ ZenServerState::Lookup(int DesiredListenPort) const
{
if (DesiredListenPort == 0 || (EntryPort == DesiredListenPort))
{
+ // If the entry's PID matches our own but we haven't registered yet,
+ // this is a stale entry from a previous process incarnation (e.g. PID 1
+ // reuse after unclean shutdown in k8s). Skip it.
+ if (m_Data[i].Pid == OurPid && m_OurEntry == nullptr)
+ {
+ continue;
+ }
+
std::error_code _;
if (IsProcessRunning(m_Data[i].Pid, _))
{
@@ -289,6 +299,8 @@ ZenServerState::Lookup(int DesiredListenPort) const
ZenServerState::ZenServerEntry*
ZenServerState::LookupByEffectivePort(int Port) const
{
+ const uint32_t OurPid = GetCurrentProcessId();
+
for (int i = 0; i < m_MaxEntryCount; ++i)
{
uint16_t EntryPort = m_Data[i].EffectiveListenPort;
@@ -296,6 +308,11 @@ ZenServerState::LookupByEffectivePort(int Port) const
{
if (EntryPort == Port)
{
+ if (m_Data[i].Pid == OurPid && m_OurEntry == nullptr)
+ {
+ continue;
+ }
+
std::error_code _;
if (IsProcessRunning(m_Data[i].Pid, _))
{
@@ -358,12 +375,26 @@ ZenServerState::Sweep()
ZEN_ASSERT(m_IsReadOnly == false);
+ const uint32_t OurPid = GetCurrentProcessId();
+
for (int i = 0; i < m_MaxEntryCount; ++i)
{
ZenServerEntry& Entry = m_Data[i];
if (Entry.DesiredListenPort)
{
+ // If the entry's PID matches our own but we haven't registered yet,
+ // this is a stale entry from a previous process incarnation (e.g. PID 1
+ // reuse after unclean shutdown in k8s). Reclaim it.
+ if (Entry.Pid == OurPid && m_OurEntry == nullptr)
+ {
+ ZEN_CONSOLE_DEBUG("Sweep - pid {} matches current process but no registration yet, reclaiming stale entry (port {})",
+ Entry.Pid.load(),
+ Entry.DesiredListenPort.load());
+ Entry.Reset();
+ continue;
+ }
+
std::error_code ErrorCode;
if (Entry.Pid != 0 && IsProcessRunning(Entry.Pid, ErrorCode) == false)
{
@@ -620,7 +651,7 @@ ZenServerInstanceInfo::Create(const Oid& SessionId, const InstanceInfoData& Data
ThrowLastError("Could not map instance info shared memory");
}
#else
- int Fd = shm_open(Name.c_str(), O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC, 0666);
+ int Fd = shm_open(Name.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0666);
if (Fd < 0)
{
ThrowLastError("Could not create instance info shared memory");
@@ -687,7 +718,7 @@ ZenServerInstanceInfo::OpenReadOnly(const Oid& SessionId)
return false;
}
#else
- int Fd = shm_open(Name.c_str(), O_RDONLY | O_CLOEXEC, 0666);
+ int Fd = shm_open(Name.c_str(), O_RDONLY, 0666);
if (Fd < 0)
{
return false;
@@ -1583,7 +1614,7 @@ ValidateLockFileInfo(const LockFileInfo& Info, std::string& OutReason)
std::optional<int>
StartupZenServer(LoggerRef LogRef, const StartupZenServerOptions& Options)
{
- auto Log = [&LogRef]() { return LogRef; };
+ ZEN_SCOPED_LOG(LogRef);
// Check if a matching server is already running
{
@@ -1653,7 +1684,7 @@ ShutdownZenServer(LoggerRef LogRef,
ZenServerState::ZenServerEntry* Entry,
const std::filesystem::path& ProgramBaseDir)
{
- auto Log = [&LogRef]() { return LogRef; };
+ ZEN_SCOPED_LOG(LogRef);
int EntryPort = (int)Entry->DesiredListenPort.load();
const uint32_t ServerProcessPid = Entry->Pid.load();
try
diff --git a/src/zenutil/zenutil.cpp b/src/zenutil/zenutil.cpp
index 2ca380c75..b282adc03 100644
--- a/src/zenutil/zenutil.cpp
+++ b/src/zenutil/zenutil.cpp
@@ -5,9 +5,11 @@
#if ZEN_WITH_TESTS
# include <zenutil/cloud/imdscredentials.h>
+# include <zenutil/consul.h>
# include <zenutil/cloud/s3client.h>
# include <zenutil/cloud/sigv4.h>
# include <zenutil/config/commandlineoptions.h>
+# include <zenutil/filesystemutils.h>
# include <zenutil/rpcrecording.h>
# include <zenutil/splitconsole/logstreamlistener.h>
# include <zenutil/process/subprocessmanager.h>
@@ -20,6 +22,8 @@ zenutil_forcelinktests()
{
cache::rpcrecord_forcelink();
commandlineoptions_forcelink();
+ consul::consul_forcelink();
+ filesystemutils_forcelink();
imdscredentials_forcelink();
logstreamlistener_forcelink();
subprocessmanager_forcelink();